From a1ab52fe15c5a9f20969aff5dce0060f6831523d Mon Sep 17 00:00:00 2001 From: Ryan Nett Date: Sun, 6 Dec 2020 23:21:39 -0800 Subject: [PATCH 01/24] Start of Indexing api Signed-off-by: Ryan Nett --- .../annotations/org/tensorflow/op/Ops.java | 16 ++ .../main/java/org/tensorflow/op/Index.java | 188 ++++++++++++++++++ .../java/org/tensorflow/op/core/Indexing.java | 130 ++++++++++++ .../test/java/org/tensorflow/IndexTest.java | 43 ++++ .../org/tensorflow/op/core/IndexingTest.java | 54 +++++ 5 files changed, 431 insertions(+) create mode 100644 tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/Index.java create mode 100644 tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/core/Indexing.java create mode 100644 tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/IndexTest.java create mode 100644 tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/op/core/IndexingTest.java diff --git a/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/Ops.java b/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/Ops.java index cf7c5b47030..70fd4ff99e4 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/Ops.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/Ops.java @@ -93,6 +93,7 @@ import org.tensorflow.op.core.Identity; import org.tensorflow.op.core.IdentityN; import org.tensorflow.op.core.ImmutableConst; +import org.tensorflow.op.core.Indexing; import org.tensorflow.op.core.Init; import org.tensorflow.op.core.InitializeTable; import org.tensorflow.op.core.InitializeTableFromTextFile; @@ -5900,6 +5901,13 @@ public StopGradient stopGradient(Operand input) { return StopGradient.create(scope, input); } + /** + * empty + */ + public StridedSlice stridedSlice(Operand input, Index... indices) { + return Indexing.stridedSlice(scope, input, indices); + } + /** * Return a strided slice from `input`. *

@@ -6012,6 +6020,14 @@ public StridedSlice stridedSlice(Operand return StridedSlice.create(scope, input, begin, end, strides, options); } + /** + * empty + */ + public StridedSliceAssign stridedSliceAssign(Operand ref, + Operand value, Index... indices) { + return Indexing.stridedSliceAssign(scope, ref, value, indices); + } + /** * Assign `value` to the sliced l-value reference of `ref`. *

diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/Index.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/Index.java new file mode 100644 index 00000000000..3e45b63010b --- /dev/null +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/Index.java @@ -0,0 +1,188 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +package org.tensorflow.op; + +import org.tensorflow.op.annotation.Endpoint; + +public abstract class Index { + + private final int begin; + private final int end; + private final int stride; + private final boolean beginMask; + private final boolean endMask; + private final boolean ellipsisMask; + private final boolean newAxisMask; + private final boolean shrinkAxisMask; + + public Index(int begin, int end, int stride, boolean beginMask, boolean endMask, boolean ellipsisMask, + boolean newAxisMask, boolean shrinkAxisMask) { + this.begin = begin; + this.end = end; + this.stride = stride; + this.beginMask = beginMask; + this.endMask = endMask; + this.ellipsisMask = ellipsisMask; + this.newAxisMask = newAxisMask; + this.shrinkAxisMask = shrinkAxisMask; + } + + public int getBegin() { + return begin; + } + + public int getEnd() { + return end; + } + + public int getStride() { + return stride; + } + + public boolean isBeginMask() { + return beginMask; + } + + public boolean isEndMask() { + return endMask; + } + + public boolean isEllipsisMask() { + return ellipsisMask; + } + + public boolean isNewAxisMask() { + return newAxisMask; + } + + public boolean isShrinkAxisMask() { + return shrinkAxisMask; + } + + + public static abstract class Singular extends Index { + + public Singular(int begin, int end, int stride, boolean beginMask, boolean endMask, boolean ellipsisMask, + boolean newAxisMask, boolean shrinkAxisMask) { + super(begin, end, stride, beginMask, endMask, ellipsisMask, newAxisMask, shrinkAxisMask); + } + } + + public static class All extends Singular { + + public All() { + super(0, 0, 1, true, true, false, false, false); + } + } + + public static All all(){ + return new All(); + } + + public static class Point extends Singular { + + private final int index; + + public Point(int index, boolean keepDim) { + super(index, index + 1, 1, false, false, false, false, !keepDim); + this.index = index; + } + + public int getIndex() { + return index; + } + } + + public static Point point(int index){ + return new Point(index, false); + } + + public static Point point(int index, boolean keepDim){ + return new Point(index, keepDim); + } + + public static class NewAxis extends Index { + + public NewAxis() { + super(0, 0, 1, false, false, false, true, false); + } + } + + public static NewAxis newAxis(){ + return new NewAxis(); + } + + public static class Ellipses extends Index { + + public Ellipses() { + super(0, 0, 1, false, false, true, false, false); + } + } + + public static Ellipses ellipses(){ + return new Ellipses(); + } + + public static class Slice extends Index { + + public Slice(Singular start, Singular end, int stride) { + super(start instanceof Point ? ((Point) start).index : 0, + end instanceof Point ? ((Point) end).index : 0, + stride, + start instanceof All, + end instanceof All, + false, + false, + false); + + if(stride < 1){ + throw new IllegalArgumentException("Can not have a stride < 1"); + } + } + } + + public static Slice slice(Singular start, Singular end, int stride){ + return new Slice(start == null ? all() : start, end == null ? all() : end, stride); + } + + public static Slice slice(int start, Singular end, int stride){ + return slice(point(start), end, stride); + } + + public static Slice slice(Singular start, int end, int stride){ + return slice(start, point(end), stride); + } + + public static Slice slice(int start, int end, int stride){ + return slice(point(start), point(end), stride); + } + + public static Slice slice(Singular start, Singular end){ + return slice(start, end, 1); + } + + public static Slice slice(int start, Singular end){ + return slice(start, end, 1); + } + + public static Slice slice(Singular start, int end){ + return slice(start, end, 1); + } + + public static Slice slice(int start, int end){ + return slice(start, end, 1); + } + +} diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/core/Indexing.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/core/Indexing.java new file mode 100644 index 00000000000..b9c4f37a27d --- /dev/null +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/core/Indexing.java @@ -0,0 +1,130 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +package org.tensorflow.op.core; + +import org.tensorflow.Operand; +import org.tensorflow.op.Index; +import org.tensorflow.op.Scope; +import org.tensorflow.op.annotation.Endpoint; +import org.tensorflow.op.annotation.Operator; +import org.tensorflow.types.family.TType; + +@Operator +public class Indexing { + + static class StridedSliceArgs { + + final int[] begin; + final int[] end; + final int[] strides; + final long beginMask; + final long endMask; + final long ellipsisMask; + final long newAxisMask; + final long shrinkAxisMask; + + public StridedSliceArgs(int[] begin, int[] end, int[] strides, long beginMask, long endMask, long ellipsisMask, + long newAxisMask, long shrinkAxisMask) { + this.begin = begin; + this.end = end; + this.strides = strides; + this.beginMask = beginMask; + this.endMask = endMask; + this.ellipsisMask = ellipsisMask; + this.newAxisMask = newAxisMask; + this.shrinkAxisMask = shrinkAxisMask; + } + } + + static StridedSliceArgs mergeIndexes(Index[] indices) { + int[] begin = new int[indices.length]; + int[] end = new int[indices.length]; + int[] strides = new int[indices.length]; + long beginMask = 0; + long endMask = 0; + long ellipsisMask = 0; + long newAxisMask = 0; + long shrinkAxisMask = 0; + + for (int i = 0; i < indices.length; i++) { + Index idx = indices[i]; + if (idx == null) { + idx = Index.all(); + } + + begin[i] = idx.getBegin(); + end[i] = idx.getEnd(); + strides[i] = idx.getStride(); + + if (idx.isBeginMask()) { + beginMask |= 1L << i; + } + + if (idx.isEndMask()) { + endMask |= 1L << i; + } + + if (idx.isEllipsisMask()) { + ellipsisMask |= 1L << i; + } + + if (idx.isNewAxisMask()) { + newAxisMask |= 1L << i; + } + + if (idx.isShrinkAxisMask()) { + shrinkAxisMask |= 1L << i; + } + } + + return new StridedSliceArgs(begin, end, strides, beginMask, endMask, ellipsisMask, newAxisMask, shrinkAxisMask); + } + + @Endpoint(name = "stridedSlice") + public static StridedSlice stridedSlice(Scope scope, Operand input, Index... indices) { + StridedSliceArgs args = mergeIndexes(indices); + return StridedSlice.create( + scope, + input, + Constant.vectorOf(scope, args.begin), + Constant.vectorOf(scope, args.end), + Constant.vectorOf(scope, args.strides), + StridedSlice.beginMask(args.beginMask), + StridedSlice.endMask(args.endMask), + StridedSlice.ellipsisMask(args.ellipsisMask), + StridedSlice.newAxisMask(args.newAxisMask), + StridedSlice.shrinkAxisMask(args.shrinkAxisMask) + ); + } + + @Endpoint(name = "stridedSliceAssign") + public static StridedSliceAssign stridedSliceAssign(Scope scope, Operand ref, Operand value, Index... indices) { + StridedSliceArgs args = mergeIndexes(indices); + return StridedSliceAssign.create( + scope, + ref, + Constant.vectorOf(scope, args.begin), + Constant.vectorOf(scope, args.end), + Constant.vectorOf(scope, args.strides), + value, + StridedSliceAssign.beginMask(args.beginMask), + StridedSliceAssign.endMask(args.endMask), + StridedSliceAssign.ellipsisMask(args.ellipsisMask), + StridedSliceAssign.newAxisMask(args.newAxisMask), + StridedSliceAssign.shrinkAxisMask(args.shrinkAxisMask) + ); + } + +} diff --git a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/IndexTest.java b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/IndexTest.java new file mode 100644 index 00000000000..18a71ab9f98 --- /dev/null +++ b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/IndexTest.java @@ -0,0 +1,43 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +package org.tensorflow; + +import static org.junit.jupiter.api.Assertions.assertTrue; + +import org.junit.Test; +import org.tensorflow.op.Index; + +public class IndexTest { + @Test + public void testNullConversions(){ + assertTrue(Index.slice(null, 0).isBeginMask(), + "Passed null for slice start but didn't set begin mask"); + + assertTrue(Index.slice(null, Index.point(0)).isBeginMask(), + "Passed null for slice start but didn't set begin mask"); + + assertTrue(Index.slice(null, null).isBeginMask(), + "Passed null for slice start but didn't set begin mask"); + + assertTrue(Index.slice(0, null).isEndMask(), + "Passed null for slice end but didn't set end mask"); + + assertTrue(Index.slice(Index.point(0), null).isEndMask(), + "Passed null for slice end but didn't set end mask"); + + assertTrue(Index.slice(null, null).isEndMask(), + "Passed null for slice end but didn't set end mask"); + } +} diff --git a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/op/core/IndexingTest.java b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/op/core/IndexingTest.java new file mode 100644 index 00000000000..f391de4cd66 --- /dev/null +++ b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/op/core/IndexingTest.java @@ -0,0 +1,54 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +package org.tensorflow.op.core; + +import static org.junit.jupiter.api.Assertions.assertArrayEquals; +import static org.junit.jupiter.api.Assertions.assertEquals; + +import org.junit.Test; +import org.tensorflow.op.Index; + +public class IndexingTest { + + @Test + public void testIndexMerge() { + Indexing.StridedSliceArgs args = Indexing.mergeIndexes(new Index[]{ + Index.point(2), + Index.point(1, true), + Index.all(), + Index.newAxis(), + Index.ellipses(), + Index.slice(Index.all(), 10), + Index.slice(10, null, 2) + } + ); + + assertArrayEquals(args.begin, new int[]{2, 1, 0, 0, 0, 0, 10}); + assertArrayEquals(args.end, new int[]{3, 2, 0, 0, 0, 10, 0}); + assertArrayEquals(args.strides, new int[]{1, 1, 1, 1, 1, 1, 2}); + assertEquals(args.beginMask, 0b0100100); + assertEquals(args.endMask, 0b1000100); + assertEquals(args.ellipsisMask, 0b0010000); + assertEquals(args.newAxisMask, 0b0001000); + assertEquals(args.shrinkAxisMask, 0b0000001); + + } + + @Test + public void testStridedSliceIndex(){ + //TODO test op + } + +} From 74ba9b0f384f716989846bd993a831615cadba99 Mon Sep 17 00:00:00 2001 From: Ryan Nett Date: Mon, 7 Dec 2020 14:31:57 -0800 Subject: [PATCH 02/24] Add javadoc Signed-off-by: Ryan Nett --- .../main/java/org/tensorflow/op/Index.java | 239 ++++++++++++++++-- .../java/org/tensorflow/op/core/Indexing.java | 77 +++++- 2 files changed, 287 insertions(+), 29 deletions(-) diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/Index.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/Index.java index 3e45b63010b..82b081b7294 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/Index.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/Index.java @@ -16,6 +16,22 @@ import org.tensorflow.op.annotation.Endpoint; +/** + * Numpy-like indexing. Supports slices, stride slices, open slices, all, newaxis, and ellipsis. + *

+ * Examples: + *

{@code
+ * x[1:-1, :, tf.newaxis, ...]
+ * // becomes
+ * stridedSlice(x, Index.slice(1, -1), Index.all(), Index.newAxis(), Index.ellipsis())
+ *
+ *
+ * x[2, 10:, :-10, 2:-2:2]
+ * // becomes
+ * stridedSlice(x, Index.point(2), Index.slice(10, null), Index.slice(null, -10), Index.slice(2, -2, 2))
+ * }
+ * + */ public abstract class Index { private final int begin; @@ -27,7 +43,7 @@ public abstract class Index { private final boolean newAxisMask; private final boolean shrinkAxisMask; - public Index(int begin, int end, int stride, boolean beginMask, boolean endMask, boolean ellipsisMask, + private Index(int begin, int end, int stride, boolean beginMask, boolean endMask, boolean ellipsisMask, boolean newAxisMask, boolean shrinkAxisMask) { this.begin = begin; this.end = end; @@ -39,63 +55,91 @@ public Index(int begin, int end, int stride, boolean beginMask, boolean endMask, this.shrinkAxisMask = shrinkAxisMask; } + /** + * @return the beginning index of the slice. + */ public int getBegin() { return begin; } + /** + * @return the end (exclusive) index of the slice. + */ public int getEnd() { return end; } + /** + * @return the stride of the slice. + */ public int getStride() { return stride; } + /** + * @return whether to begin at the beginning. + */ public boolean isBeginMask() { return beginMask; } + /** + * @return whether to end at the end. + */ public boolean isEndMask() { return endMask; } + /** + * @return is this index an {@link Ellipses} + */ public boolean isEllipsisMask() { return ellipsisMask; } + /** + * @return should this index add a new dimension. + */ public boolean isNewAxisMask() { return newAxisMask; } + /** + * @return should this index shrink its dimension. + */ public boolean isShrinkAxisMask() { return shrinkAxisMask; } - + /** + * An index that can be used as the start or end of a slice. + */ public static abstract class Singular extends Index { - public Singular(int begin, int end, int stride, boolean beginMask, boolean endMask, boolean ellipsisMask, + private Singular(int begin, int end, int stride, boolean beginMask, boolean endMask, boolean ellipsisMask, boolean newAxisMask, boolean shrinkAxisMask) { super(begin, end, stride, beginMask, endMask, ellipsisMask, newAxisMask, shrinkAxisMask); } } + /** + * An index that gets the entire dimension. + */ public static class All extends Singular { - public All() { + private All() { super(0, 0, 1, true, true, false, false, false); } } - public static All all(){ - return new All(); - } - + /** + * An index that gets a single point from its dimension, collapsing it by default. + */ public static class Point extends Singular { private final int index; - public Point(int index, boolean keepDim) { + private Point(int index, boolean keepDim) { super(index, index + 1, 1, false, false, false, false, !keepDim); this.index = index; } @@ -105,39 +149,32 @@ public int getIndex() { } } - public static Point point(int index){ - return new Point(index, false); - } - - public static Point point(int index, boolean keepDim){ - return new Point(index, keepDim); - } - + /** + * An index that adds a new dimension of size 1 where it is used. + */ public static class NewAxis extends Index { - public NewAxis() { + private NewAxis() { super(0, 0, 1, false, false, false, true, false); } } - public static NewAxis newAxis(){ - return new NewAxis(); - } - + /** + * An index that expands to get all possible dimensions. + */ public static class Ellipses extends Index { - public Ellipses() { + private Ellipses() { super(0, 0, 1, false, false, true, false, false); } } - public static Ellipses ellipses(){ - return new Ellipses(); - } - + /** + * An index to get a slice of its dimension, with an optional stride. + */ public static class Slice extends Index { - public Slice(Singular start, Singular end, int stride) { + private Slice(Singular start, Singular end, int stride) { super(start instanceof Point ? ((Point) start).index : 0, end instanceof Point ? ((Point) end).index : 0, stride, @@ -153,34 +190,180 @@ public Slice(Singular start, Singular end, int stride) { } } + /** + * An index that gets the entire dimension. + *

+ * Equivalent to Python's {@code :}. + */ + public static All all(){ + return new All(); + } + + /** + * An index that gets the value at the given position, and collapses the dimension (removing it). + *

+ * Equivalent to Python's indexing, and supports negative values in the same way. + * + * @param index The position to get. + */ + public static Point point(int index){ + return new Point(index, false); + } + + + /** + * An index that gets the value at the given position, and collapses the dimension (removing it) if keepDim is false. + * + * @param index The position to get. + * @param keepDim Whether to keep the dimension as size 1. + */ + public static Point point(int index, boolean keepDim){ + return new Point(index, keepDim); + } + + /** + * An index that adds a new dimension of size 1 where it is used. + *

+ * Equivalent to Python's {@code np.newaxis}, {@code tf.newaxis} or {@code None}. + */ + public static NewAxis newAxis(){ + return new NewAxis(); + } + + /** + * An index that expands to get all possible dimensions. + *

+ * Equivalent to Python's {@code ...}. + */ + public static Ellipses ellipses(){ + return new Ellipses(); + } + + /** + * An index to get a slice of its dimension. + * Start and end can be null or All to start or end at the beginning or end, respectively. + *

+ * Equivalent to Python's {@code :} slicing syntax: + *

{@code
+   * :
+   * // becomes
+   * Index.all()
+   * Index.slice(null, null)
+   * Index.slice(Index.all(), Index.all())
+   *
+   * 2:
+   * // becomes
+   * Index.slice(2, null)
+   *
+   * :2
+   * // becomes
+   * Index.slice(null, 2)
+   *
+   * 2:10
+   * // becomes
+   * Index.slice(2, 10)
+   *
+   * :2
+   * // becomes
+   * Index.slice(null, null, 2)
+   *
+   * 2:10:2
+   * //becomes
+   * Index.slice(2, 10, 2)
+   * }
+ * + * @param start Where to start the slice. Starts at the beginning if null or All. + * @param end Where to end the slice (exclusive). Ends at the end if null or All. + * @param stride The stride. + */ public static Slice slice(Singular start, Singular end, int stride){ return new Slice(start == null ? all() : start, end == null ? all() : end, stride); } + + /** + * An index to get a slice of its dimension. + * End can be null or All to end at the end. + * + * @param start Where to start the slice. + * @param end Where to end the slice (exclusive). Ends at the end if null or All. + * @param stride The stride. + * @see #slice(Singular, Singular, int) + */ public static Slice slice(int start, Singular end, int stride){ return slice(point(start), end, stride); } + + /** + * An index to get a slice of its dimension. + * Start can be null or All to start at the beginning. + * + * @param start Where to start the slice. Starts at the beginning if null or All. + * @param end Where to end the slice (exclusive). + * @param stride The stride. + * @see #slice(Singular, Singular, int) + */ public static Slice slice(Singular start, int end, int stride){ return slice(start, point(end), stride); } + + /** + * An index to get a slice of its dimension. + * + * @param start Where to start the slice. + * @param end Where to end the slice (exclusive). + * @param stride The stride. + * @see #slice(Singular, Singular, int) + */ public static Slice slice(int start, int end, int stride){ return slice(point(start), point(end), stride); } + /** + * An index to get a slice of its dimension. + * Start and end can be null or All to start or end at the beginning or end, respectively. + * + * @param start Where to start the slice. Starts at the beginning if null or All. + * @param end Where to end the slice (exclusive). Ends at the end if null or All. + * @see #slice(Singular, Singular, int) + */ public static Slice slice(Singular start, Singular end){ return slice(start, end, 1); } + /** + * An index to get a slice of its dimension. + * End can be null or All to end at the end. + * + * @param start Where to start the slice. + * @param end Where to end the slice (exclusive). Ends at the end if null or All. + * @see #slice(Singular, Singular, int) + */ public static Slice slice(int start, Singular end){ return slice(start, end, 1); } + /** + * An index to get a slice of its dimension. + * Start can be null or All to start at the beginning. + * + * @param start Where to start the slice. Starts at the beginning if null or All. + * @param end Where to end the slice (exclusive). + * @see #slice(Singular, Singular, int) + */ public static Slice slice(Singular start, int end){ return slice(start, end, 1); } + /** + * An index to get a slice of its dimension. + * + * @param start Where to start the slice. + * @param end Where to end the slice (exclusive). + * @see #slice(Singular, Singular, int) + */ public static Slice slice(int start, int end){ return slice(start, end, 1); } diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/core/Indexing.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/core/Indexing.java index b9c4f37a27d..a561d458a35 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/core/Indexing.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/core/Indexing.java @@ -16,11 +16,17 @@ import org.tensorflow.Operand; import org.tensorflow.op.Index; +import org.tensorflow.op.Index.Singular; import org.tensorflow.op.Scope; import org.tensorflow.op.annotation.Endpoint; import org.tensorflow.op.annotation.Operator; import org.tensorflow.types.family.TType; +/** + * Helper endpoint methods for Python like indexing. + * + * @see Index + */ @Operator public class Indexing { @@ -77,6 +83,8 @@ static StridedSliceArgs mergeIndexes(Index[] indices) { } if (idx.isEllipsisMask()) { + if(ellipsisMask != 0) + throw new IllegalArgumentException("Can not have two ellipsis in a slice"); ellipsisMask |= 1L << i; } @@ -91,7 +99,56 @@ static StridedSliceArgs mergeIndexes(Index[] indices) { return new StridedSliceArgs(begin, end, strides, beginMask, endMask, ellipsisMask, newAxisMask, shrinkAxisMask); } - + /** + * Return a strided slice from `input`. + *

+ * The goal of this op is to produce a new tensor with a subset of + * the elements from the `n` dimensional `input` tensor. The subset is chosen using + * a sequence of `m` sparse range specifications encoded into the arguments + * of this function. Note, in some cases + * `m` could be equal to `n`, but this need not be the case. Each + * range specification entry can be one of the following: + *

+ * - An ellipsis (...) using {@link Index#ellipses()}. Ellipses are used to imply zero or more + * dimensions of full-dimension selection and are produced using + * `ellipsis_mask`. For example, `foo[...]` is the identity slice. + *

+ * - A new axis using {@link Index#newAxis()}. This is used to insert a new shape=1 dimension and is + * produced using `new_axis_mask`. For example, `foo[:, ...]` where + * `foo` is shape `(3, 4)` produces a `(1, 3, 4)` tensor. + *

+ * - A range `begin:end:stride` using {@link Index#slice(Singular, Singular, int) Index.slice()} or {@link Index#all()}. This is used to specify how much to choose from + * a given dimension. `stride` can be any integer but 0. `begin` is an integer + * which represents the index of the first value to select while `end` represents + * the index of the last value to select. The number of values selected in each + * dimension is `end - begin` if `stride > 0` and `begin - end` if `stride < 0`. + * `begin` and `end` can be negative where `-1` is the last element, `-2` is + * the second to last. `begin_mask` controls whether to replace the explicitly + * given `begin` with an implicit effective value of `0` if `stride > 0` and + * `-1` if `stride < 0`. `end_mask` is analogous but produces the number + * required to create the largest open interval. For example, given a shape + * `(3,)` tensor `foo[:]`, the effective `begin` and `end` are `0` and `3`. Do + * not assume this is equivalent to `foo[0:-1]` which has an effective `begin` + * and `end` of `0` and `2`. Another example is `foo[-2::-1]` which reverses the + * first dimension of a tensor while dropping the last two (in the original + * order elements). For example `foo = [1,2,3,4]; foo[-2::-1]` is `[4,3]`. + *

+ * - A single index using {@link Index#point(int)}. This is used to keep only elements that have a given + * index. For example (`foo[2, :]` on a shape `(5,6)` tensor produces a + * shape `(6,)` tensor. This is encoded in `begin` and `end` and + * `shrink_axis_mask`. + *

+ * + * Requirements: + * `0 != strides[i] for i in [0, m)` + * Only one ellipsis. + * + * @param scope current scope + * @param data type for {@code output()} output + * @param input + * @param indices The indices to slice. See {@link Index}. + * @return a new instance of StridedSlice + */ @Endpoint(name = "stridedSlice") public static StridedSlice stridedSlice(Scope scope, Operand input, Index... indices) { StridedSliceArgs args = mergeIndexes(indices); @@ -109,6 +166,24 @@ public static StridedSlice stridedSlice(Scope scope, Operan ); } + /** + * Assign `value` to the sliced l-value reference of `ref`. + *

+ * The values of `value` are assigned to the positions in the variable + * `ref` that are selected by the slice parameters. The slice parameters + * `begin`, `end`, `strides`, etc. work exactly as in `StridedSlice`. + *

+ * NOTE this op currently does not support broadcasting and so `value`'s + * shape must be exactly the shape produced by the slice of `ref`. + * + * @param data type for {@code outputRef()} output + * @param scope current scope + * @param ref the tensor to assign to. + * @param value the value to assign. + * @param indices The indices to slice. See {@link Index}. + * @return a new instance of StridedSliceAssign + * @see #stridedSlice(Scope, Operand, Index...) + */ @Endpoint(name = "stridedSliceAssign") public static StridedSliceAssign stridedSliceAssign(Scope scope, Operand ref, Operand value, Index... indices) { StridedSliceArgs args = mergeIndexes(indices); From 1c078ef324739f8b7c1f022e0b1f077ebdaa5afb Mon Sep 17 00:00:00 2001 From: Ryan Nett Date: Mon, 7 Dec 2020 14:41:36 -0800 Subject: [PATCH 03/24] codegen Signed-off-by: Ryan Nett --- .../annotations/org/tensorflow/op/Ops.java | 67 ++++++++++++++++++- .../java/org/tensorflow/op/core/Indexing.java | 3 +- 2 files changed, 67 insertions(+), 3 deletions(-) diff --git a/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/Ops.java b/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/Ops.java index 70fd4ff99e4..ecc19accea3 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/Ops.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/Ops.java @@ -5902,7 +5902,55 @@ public StopGradient stopGradient(Operand input) { } /** - * empty + * Return a strided slice from `input`. + *

+ * The goal of this op is to produce a new tensor with a subset of + * the elements from the `n` dimensional `input` tensor. The subset is chosen using + * a sequence of `m` sparse range specifications encoded into the arguments + * of this function. Note, in some cases + * `m` could be equal to `n`, but this need not be the case. Each + * range specification entry can be one of the following: + *

+ * - An ellipsis (...) using {@link Index#ellipses()}. Ellipses are used to imply zero or more + * dimensions of full-dimension selection and are produced using + * `ellipsis_mask`. For example, `foo[...]` is the identity slice. + *

+ * - A new axis using {@link Index#newAxis()}. This is used to insert a new shape=1 dimension and is + * produced using `new_axis_mask`. For example, `foo[:, ...]` where + * `foo` is shape `(3, 4)` produces a `(1, 3, 4)` tensor. + *

+ * - A range `begin:end:stride` using {@link Index#slice(Singular, Singular, int) Index.slice()} or {@link Index#all()}. This is used to specify how much to choose from + * a given dimension. `stride` can be any integer but 0. `begin` is an integer + * which represents the index of the first value to select while `end` represents + * the index of the last value to select. The number of values selected in each + * dimension is `end - begin` if `stride > 0` and `begin - end` if `stride < 0`. + * `begin` and `end` can be negative where `-1` is the last element, `-2` is + * the second to last. `begin_mask` controls whether to replace the explicitly + * given `begin` with an implicit effective value of `0` if `stride > 0` and + * `-1` if `stride < 0`. `end_mask` is analogous but produces the number + * required to create the largest open interval. For example, given a shape + * `(3,)` tensor `foo[:]`, the effective `begin` and `end` are `0` and `3`. Do + * not assume this is equivalent to `foo[0:-1]` which has an effective `begin` + * and `end` of `0` and `2`. Another example is `foo[-2::-1]` which reverses the + * first dimension of a tensor while dropping the last two (in the original + * order elements). For example `foo = [1,2,3,4]; foo[-2::-1]` is `[4,3]`. + *

+ * - A single index using {@link Index#point(int)}. This is used to keep only elements that have a given + * index. For example (`foo[2, :]` on a shape `(5,6)` tensor produces a + * shape `(6,)` tensor. This is encoded in `begin` and `end` and + * `shrink_axis_mask`. + *

+ * + * Requirements: + * `0 != strides[i] for i in [0, m)` + * Only one ellipsis. + * + * @param scope current scope + * @param data type for {@code output()} output + * @param input + * @param indices The indices to slice. See {@link Index}. + * @return a new instance of StridedSlice + * @see Index */ public StridedSlice stridedSlice(Operand input, Index... indices) { return Indexing.stridedSlice(scope, input, indices); @@ -6021,7 +6069,22 @@ public StridedSlice stridedSlice(Operand } /** - * empty + * Assign `value` to the sliced l-value reference of `ref`. + *

+ * The values of `value` are assigned to the positions in the variable + * `ref` that are selected by the slice parameters. The slice parameters + * `begin`, `end`, `strides`, etc. work exactly as in `StridedSlice`. + *

+ * NOTE this op currently does not support broadcasting and so `value`'s + * shape must be exactly the shape produced by the slice of `ref`. + * + * @param data type for {@code outputRef()} output + * @param scope current scope + * @param ref the tensor to assign to. + * @param value the value to assign. + * @param indices The indices to slice. See {@link Index}. + * @return a new instance of StridedSliceAssign + * @see org.tensorflow.op.Ops#stridedSlice(Operand, Index...) */ public StridedSliceAssign stridedSliceAssign(Operand ref, Operand value, Index... indices) { diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/core/Indexing.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/core/Indexing.java index a561d458a35..0d0c0ce5cc0 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/core/Indexing.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/core/Indexing.java @@ -148,6 +148,7 @@ static StridedSliceArgs mergeIndexes(Index[] indices) { * @param input * @param indices The indices to slice. See {@link Index}. * @return a new instance of StridedSlice + * @see Index */ @Endpoint(name = "stridedSlice") public static StridedSlice stridedSlice(Scope scope, Operand input, Index... indices) { @@ -182,7 +183,7 @@ public static StridedSlice stridedSlice(Scope scope, Operan * @param value the value to assign. * @param indices The indices to slice. See {@link Index}. * @return a new instance of StridedSliceAssign - * @see #stridedSlice(Scope, Operand, Index...) + * @see org.tensorflow.op.Ops#stridedSlice(Operand, Index...) */ @Endpoint(name = "stridedSliceAssign") public static StridedSliceAssign stridedSliceAssign(Scope scope, Operand ref, Operand value, Index... indices) { From 3abac560a6b9c748547083b925d03b1eb692bed5 Mon Sep 17 00:00:00 2001 From: Ryan Nett Date: Mon, 7 Dec 2020 14:53:16 -0800 Subject: [PATCH 04/24] op test Signed-off-by: Ryan Nett --- .../org/tensorflow/op/core/IndexingTest.java | 46 ++++++++++++++----- 1 file changed, 35 insertions(+), 11 deletions(-) diff --git a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/op/core/IndexingTest.java b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/op/core/IndexingTest.java index f391de4cd66..d338367a285 100644 --- a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/op/core/IndexingTest.java +++ b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/op/core/IndexingTest.java @@ -16,9 +16,16 @@ import static org.junit.jupiter.api.Assertions.assertArrayEquals; import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.tensorflow.ndarray.Shape.of; import org.junit.Test; +import org.tensorflow.Graph; +import org.tensorflow.Session; +import org.tensorflow.Tensor; +import org.tensorflow.ndarray.Shape; import org.tensorflow.op.Index; +import org.tensorflow.op.Scope; +import org.tensorflow.types.TFloat32; public class IndexingTest { @@ -30,25 +37,42 @@ public void testIndexMerge() { Index.all(), Index.newAxis(), Index.ellipses(), - Index.slice(Index.all(), 10), - Index.slice(10, null, 2) + Index.slice(Index.all(), 4), + Index.slice(4, null, 2) } ); - assertArrayEquals(args.begin, new int[]{2, 1, 0, 0, 0, 0, 10}); - assertArrayEquals(args.end, new int[]{3, 2, 0, 0, 0, 10, 0}); - assertArrayEquals(args.strides, new int[]{1, 1, 1, 1, 1, 1, 2}); - assertEquals(args.beginMask, 0b0100100); - assertEquals(args.endMask, 0b1000100); - assertEquals(args.ellipsisMask, 0b0010000); - assertEquals(args.newAxisMask, 0b0001000); - assertEquals(args.shrinkAxisMask, 0b0000001); + assertArrayEquals(new int[]{2, 1, 0, 0, 0, 0, 10}, args.begin); + assertArrayEquals(new int[]{3, 2, 0, 0, 0, 10, 0}, args.end); + assertArrayEquals(new int[]{1, 1, 1, 1, 1, 1, 2}, args.strides); + assertEquals(0b0100100, args.beginMask); + assertEquals(0b1000100, args.endMask); + assertEquals(0b0010000, args.ellipsisMask); + assertEquals(0b0001000, args.newAxisMask); + assertEquals(0b0000001, args.shrinkAxisMask); } @Test public void testStridedSliceIndex(){ - //TODO test op + try (Graph g = new Graph(); + Session sess = new Session(g)) { + Scope scope = new Scope(g); + long[] shape = {10, 10, 10, 10, 10, 10, 10, 10}; + Zeros op = Zeros.create(scope, Constant.vectorOf(scope, shape), TFloat32.DTYPE); + StridedSlice output = Indexing.stridedSlice(scope, op, + Index.point(2), + Index.point(1, true), + Index.all(), + Index.newAxis(), + Index.ellipses(), + Index.slice(Index.all(), 4), + Index.slice(4, null, 2) + ); + try (Tensor result = sess.runner().fetch(op.asOutput()).run().get(0).expect(TFloat32.DTYPE)) { + assertEquals(Shape.of(1, 10, 1, 10, 10, 10, 4, 3), result.data().shape(), "Slice index didn't match expected (Python)"); + } + } } } From 659fd686f6294c8a2ae6ece0988ef76a4f2b7476 Mon Sep 17 00:00:00 2001 From: Ryan Nett Date: Mon, 7 Dec 2020 14:53:45 -0800 Subject: [PATCH 05/24] fix test Signed-off-by: Ryan Nett --- .../src/test/java/org/tensorflow/op/core/IndexingTest.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/op/core/IndexingTest.java b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/op/core/IndexingTest.java index d338367a285..b8cc8109237 100644 --- a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/op/core/IndexingTest.java +++ b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/op/core/IndexingTest.java @@ -69,7 +69,7 @@ public void testStridedSliceIndex(){ Index.slice(Index.all(), 4), Index.slice(4, null, 2) ); - try (Tensor result = sess.runner().fetch(op.asOutput()).run().get(0).expect(TFloat32.DTYPE)) { + try (Tensor result = sess.runner().fetch(output.asOutput()).run().get(0).expect(TFloat32.DTYPE)) { assertEquals(Shape.of(1, 10, 1, 10, 10, 10, 4, 3), result.data().shape(), "Slice index didn't match expected (Python)"); } } From 73e88905acae5e4819714c8af5541cb71d940a4a Mon Sep 17 00:00:00 2001 From: Ryan Nett Date: Mon, 7 Dec 2020 14:56:30 -0800 Subject: [PATCH 06/24] explain expected shape, fix slice input Signed-off-by: Ryan Nett --- .../org/tensorflow/op/core/IndexingTest.java | 37 ++++++++----------- 1 file changed, 16 insertions(+), 21 deletions(-) diff --git a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/op/core/IndexingTest.java b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/op/core/IndexingTest.java index b8cc8109237..d2d29d59f6a 100644 --- a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/op/core/IndexingTest.java +++ b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/op/core/IndexingTest.java @@ -29,21 +29,23 @@ public class IndexingTest { + // [2, 1:2, :, tf.newaxis, ..., :4, 4::2] + private static Index[] slice = new Index[]{ + Index.point(2), + Index.point(1, true), + Index.all(), + Index.newAxis(), + Index.ellipses(), + Index.slice(Index.all(), 4), + Index.slice(4, null, 2) + }; + @Test public void testIndexMerge() { - Indexing.StridedSliceArgs args = Indexing.mergeIndexes(new Index[]{ - Index.point(2), - Index.point(1, true), - Index.all(), - Index.newAxis(), - Index.ellipses(), - Index.slice(Index.all(), 4), - Index.slice(4, null, 2) - } - ); + Indexing.StridedSliceArgs args = Indexing.mergeIndexes(slice); - assertArrayEquals(new int[]{2, 1, 0, 0, 0, 0, 10}, args.begin); - assertArrayEquals(new int[]{3, 2, 0, 0, 0, 10, 0}, args.end); + assertArrayEquals(new int[]{2, 1, 0, 0, 0, 0, 4}, args.begin); + assertArrayEquals(new int[]{3, 2, 0, 0, 0, 4, 0}, args.end); assertArrayEquals(new int[]{1, 1, 1, 1, 1, 1, 2}, args.strides); assertEquals(0b0100100, args.beginMask); assertEquals(0b1000100, args.endMask); @@ -60,16 +62,9 @@ public void testStridedSliceIndex(){ Scope scope = new Scope(g); long[] shape = {10, 10, 10, 10, 10, 10, 10, 10}; Zeros op = Zeros.create(scope, Constant.vectorOf(scope, shape), TFloat32.DTYPE); - StridedSlice output = Indexing.stridedSlice(scope, op, - Index.point(2), - Index.point(1, true), - Index.all(), - Index.newAxis(), - Index.ellipses(), - Index.slice(Index.all(), 4), - Index.slice(4, null, 2) - ); + StridedSlice output = Indexing.stridedSlice(scope, op, slice); try (Tensor result = sess.runner().fetch(output.asOutput()).run().get(0).expect(TFloat32.DTYPE)) { + // expected shape from Python tensorflow assertEquals(Shape.of(1, 10, 1, 10, 10, 10, 4, 3), result.data().shape(), "Slice index didn't match expected (Python)"); } } From 57dc61111ff6ffc3083ebc37bea6f323c7fc189c Mon Sep 17 00:00:00 2001 From: Ryan Nett Date: Mon, 7 Dec 2020 15:23:12 -0800 Subject: [PATCH 07/24] add a final Signed-off-by: Ryan Nett --- .../src/test/java/org/tensorflow/op/core/IndexingTest.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/op/core/IndexingTest.java b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/op/core/IndexingTest.java index d2d29d59f6a..2f5453017e7 100644 --- a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/op/core/IndexingTest.java +++ b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/op/core/IndexingTest.java @@ -30,7 +30,7 @@ public class IndexingTest { // [2, 1:2, :, tf.newaxis, ..., :4, 4::2] - private static Index[] slice = new Index[]{ + private static final Index[] slice = new Index[]{ Index.point(2), Index.point(1, true), Index.all(), From 0f5211dbd80b57f5b6498a786170b04ecde6eec6 Mon Sep 17 00:00:00 2001 From: Ryan Nett Date: Tue, 8 Dec 2020 18:58:04 -0800 Subject: [PATCH 08/24] fix constructor visibility Signed-off-by: Ryan Nett --- .../src/main/java/org/tensorflow/op/core/Indexing.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/core/Indexing.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/core/Indexing.java index 0d0c0ce5cc0..a82e544ee67 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/core/Indexing.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/core/Indexing.java @@ -41,7 +41,7 @@ static class StridedSliceArgs { final long newAxisMask; final long shrinkAxisMask; - public StridedSliceArgs(int[] begin, int[] end, int[] strides, long beginMask, long endMask, long ellipsisMask, + private StridedSliceArgs(int[] begin, int[] end, int[] strides, long beginMask, long endMask, long ellipsisMask, long newAxisMask, long shrinkAxisMask) { this.begin = begin; this.end = end; From 1c356152ae90e9aad1a00aabc67bc3682ef816ad Mon Sep 17 00:00:00 2001 From: Ryan Nett Date: Wed, 9 Dec 2020 13:01:11 -0800 Subject: [PATCH 09/24] fix range check Signed-off-by: Ryan Nett --- .../src/main/java/org/tensorflow/op/Index.java | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/Index.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/Index.java index 82b081b7294..5655ced90c5 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/Index.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/Index.java @@ -184,8 +184,8 @@ private Slice(Singular start, Singular end, int stride) { false, false); - if(stride < 1){ - throw new IllegalArgumentException("Can not have a stride < 1"); + if(stride != 0){ + throw new IllegalArgumentException("Can not have a stride of 0"); } } } From e446f3a1d1bdeb32b4643b8af8ecb237e546bee6 Mon Sep 17 00:00:00 2001 From: Ryan Nett Date: Wed, 9 Dec 2020 13:01:34 -0800 Subject: [PATCH 10/24] start of adapting ndarray indexing Signed-off-by: Ryan Nett --- .../impl/dimension/DimensionalSpace.java | 61 +++++++++- .../org/tensorflow/ndarray/index/All.java | 12 +- .../java/org/tensorflow/ndarray/index/At.java | 31 +++++- .../tensorflow/ndarray/index/Ellipsis.java | 48 ++++++++ .../org/tensorflow/ndarray/index/Index.java | 15 +++ .../org/tensorflow/ndarray/index/Indices.java | 67 ++++++++++- .../org/tensorflow/ndarray/index/NewAxis.java | 53 +++++++++ .../org/tensorflow/ndarray/index/Slice.java | 104 ++++++++++++++++++ .../tensorflow/ndarray/index/TensorIndex.java | 50 +++++++++ 9 files changed, 426 insertions(+), 15 deletions(-) create mode 100644 ndarray/src/main/java/org/tensorflow/ndarray/index/Ellipsis.java create mode 100644 ndarray/src/main/java/org/tensorflow/ndarray/index/NewAxis.java create mode 100644 ndarray/src/main/java/org/tensorflow/ndarray/index/Slice.java create mode 100644 ndarray/src/main/java/org/tensorflow/ndarray/index/TensorIndex.java diff --git a/ndarray/src/main/java/org/tensorflow/ndarray/impl/dimension/DimensionalSpace.java b/ndarray/src/main/java/org/tensorflow/ndarray/impl/dimension/DimensionalSpace.java index e4bdc53c713..1e4303e04f9 100644 --- a/ndarray/src/main/java/org/tensorflow/ndarray/impl/dimension/DimensionalSpace.java +++ b/ndarray/src/main/java/org/tensorflow/ndarray/impl/dimension/DimensionalSpace.java @@ -39,20 +39,38 @@ public RelativeDimensionalSpace mapTo(Index[] indices) { throw new ArrayIndexOutOfBoundsException(); } int dimIdx = 0; + int indexIdx = 0; int newDimIdx = 0; int segmentationIdx = -1; long initialOffset = 0; - Dimension[] newDimensions = new Dimension[dimensions.length]; - while (dimIdx < indices.length) { + int newAxes = 0; + boolean seenEllipsis = false; + for(Index idx : indices){ + if(idx.isNewAxis()){ + newAxes += 1; + } + if(idx.isEllipsis()){ + if(seenEllipsis){ + throw new IllegalArgumentException("Only one ellipsis allowed"); + } else { + seenEllipsis = true; + } + } + } + int newLength = dimensions.length + newAxes; + + Dimension[] newDimensions = new Dimension[newLength]; + while (indexIdx < indices.length) { if (indices[dimIdx].isPoint()) { // When an index targets a single point in a given dimension, calculate the offset of this // point and cumulate the offset of any subsequent point as well long offset = 0; do { - offset += indices[dimIdx].mapCoordinate(0, dimensions[dimIdx]); - } while (++dimIdx < indices.length && indices[dimIdx].isPoint()); + offset += indices[indexIdx].mapCoordinate(0, dimensions[dimIdx]); + dimIdx++; + } while (++indexIdx < indices.length && indices[indexIdx].isPoint()); // If this is the first index, then the offset is the position of the whole dimension // space within the original one. If not, then we apply the offset to the last vectorial @@ -65,14 +83,47 @@ public RelativeDimensionalSpace mapTo(Index[] indices) { segmentationIdx = newDimIdx - 1; } + } else if(indices[indexIdx].isNewAxis()) { + long newSize; + if(dimIdx == 0){ + // includes everything. Should really include future reduction (at()) but that doesn't seem to cause issues + // elsewhere + newSize = dimensions[0].numElements() * dimensions[0].elementSize(); + } else { + newSize = dimensions[dimIdx - 1].elementSize(); + } + + newDimensions[newDimIdx] = new Axis(1, newSize); + segmentationIdx = newDimIdx; // is this correct? + ++newDimIdx; + ++indexIdx; + } else if(indices[indexIdx].isEllipsis()){ + int remainingDimensions = dimensions.length - dimIdx; + int requiredDimensions = 0; + for(int i = indexIdx + 1 ; i < indices.length ; i++){ + if(!indices[i].isNewAxis()){ + requiredDimensions++; + } + } + // while the number of dimensions left < the number of indices that consume axes + while(remainingDimensions > requiredDimensions){ + Dimension dim = dimensions[dimIdx++]; + if (dim.isSegmented()) { + segmentationIdx = newDimIdx; + } + newDimensions[newDimIdx++] = dim; + remainingDimensions--; + } + indexIdx++; } else { // Map any other index to the appropriate dimension of this space - Dimension newDimension = indices[dimIdx].apply(dimensions[dimIdx++]); + Dimension newDimension = indices[indexIdx].apply(dimensions[dimIdx++]); newDimensions[newDimIdx] = newDimension; if (newDimension.isSegmented()) { segmentationIdx = newDimIdx; } ++newDimIdx; + ++indexIdx; } } diff --git a/ndarray/src/main/java/org/tensorflow/ndarray/index/All.java b/ndarray/src/main/java/org/tensorflow/ndarray/index/All.java index b38e33d5e22..1efd6bafe53 100644 --- a/ndarray/src/main/java/org/tensorflow/ndarray/index/All.java +++ b/ndarray/src/main/java/org/tensorflow/ndarray/index/All.java @@ -18,7 +18,7 @@ import org.tensorflow.ndarray.impl.dimension.Dimension; -final class All implements Index { +final class All implements TensorIndex { static final All INSTANCE = new All(); @@ -39,4 +39,14 @@ public Dimension apply(Dimension dim) { private All() { } + + @Override + public boolean beginMask() { + return true; + } + + @Override + public boolean endMask() { + return true; + } } diff --git a/ndarray/src/main/java/org/tensorflow/ndarray/index/At.java b/ndarray/src/main/java/org/tensorflow/ndarray/index/At.java index 5d92ee3286b..7f28f209c00 100644 --- a/ndarray/src/main/java/org/tensorflow/ndarray/index/At.java +++ b/ndarray/src/main/java/org/tensorflow/ndarray/index/At.java @@ -18,7 +18,7 @@ import org.tensorflow.ndarray.impl.dimension.Dimension; -final class At implements Index { +final class At implements TensorIndex { @Override public long numElements(Dimension dim) { @@ -27,22 +27,45 @@ public long numElements(Dimension dim) { @Override public long mapCoordinate(long coordinate, Dimension dim) { + long coord = this.coord > 0 ? this.coord : dim.numElements() - this.coord; return dim.positionOf(coord); // TODO validate coordinate is 0? } @Override public Dimension apply(Dimension dim) { - throw new IllegalStateException(); // FIXME? + if(keepDim){ + return dim.withIndex(this); + } + else { + throw new IllegalStateException(); // FIXME? + } } @Override public boolean isPoint() { - return true; + return !keepDim; } - At(long coord) { + At(long coord, boolean keepDim) { this.coord = coord; + this.keepDim = keepDim; } private final long coord; + private final boolean keepDim; + + @Override + public long begin() { + return coord; + } + + @Override + public long end() { + return coord + 1; + } + + @Override + public boolean shrinkAxisMask() { + return !keepDim; + } } diff --git a/ndarray/src/main/java/org/tensorflow/ndarray/index/Ellipsis.java b/ndarray/src/main/java/org/tensorflow/ndarray/index/Ellipsis.java new file mode 100644 index 00000000000..f3e9247431d --- /dev/null +++ b/ndarray/src/main/java/org/tensorflow/ndarray/index/Ellipsis.java @@ -0,0 +1,48 @@ +/* + Copyright 2020 The TensorFlow Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + ============================================================================== + */ +package org.tensorflow.ndarray.index; + +import org.tensorflow.ndarray.impl.dimension.Dimension; + +final class Ellipsis implements TensorIndex{ + + static final Ellipsis INSTANCE = new Ellipsis(); + + private Ellipsis(){ + + } + + @Override + public long numElements(Dimension dim) { + throw new IllegalStateException(); + } + + @Override + public long mapCoordinate(long coordinate, Dimension dim) { + throw new IllegalStateException(); + } + + @Override + public boolean isEllipsis() { + return true; + } + + @Override + public boolean ellipsisMask() { + return true; + } +} diff --git a/ndarray/src/main/java/org/tensorflow/ndarray/index/Index.java b/ndarray/src/main/java/org/tensorflow/ndarray/index/Index.java index da6aa9049f6..b459ff7a99d 100644 --- a/ndarray/src/main/java/org/tensorflow/ndarray/index/Index.java +++ b/ndarray/src/main/java/org/tensorflow/ndarray/index/Index.java @@ -74,4 +74,19 @@ default Dimension apply(Dimension dim) { default boolean isPoint() { return false; } + + /** + * Returns true if this index is a new axis, adding a dimension of size 1 + */ + default boolean isNewAxis() { + return false; + } + + /** + * Returns true if this index is an ellipsis, expanding to take as many dimensions as possible + * (and applying all() to them) + */ + default boolean isEllipsis() { + return false; + } } diff --git a/ndarray/src/main/java/org/tensorflow/ndarray/index/Indices.java b/ndarray/src/main/java/org/tensorflow/ndarray/index/Indices.java index abc72195c82..5c449df8a43 100644 --- a/ndarray/src/main/java/org/tensorflow/ndarray/index/Indices.java +++ b/ndarray/src/main/java/org/tensorflow/ndarray/index/Indices.java @@ -40,8 +40,8 @@ public final class Indices { * @param coord coordinate of the element on the indexed axis * @return index */ - public static Index at(long coord) { - return new At(coord); + public static TensorIndex at(long coord) { + return new At(coord, false); } /** @@ -54,11 +54,50 @@ public static Index at(long coord) { * @return index * @throws IllegalRankException if {@code coord} is not a scalar (rank 0) */ - public static Index at(NdArray coord) { + public static TensorIndex at(NdArray coord) { if (coord.rank() > 0) { throw new IllegalRankException("Only scalars are accepted as a value index"); } - return new At(coord.getObject().longValue()); + return new At(coord.getObject().longValue(), false); + } + + /** + * A coordinate that selects a specific element on a given dimension. + * + *

When this index is applied to a given dimension, the dimension is resolved as a + * single element and therefore, if {@code keepDim} is false, is excluded from the computation of the rank. + * If {@code} keepDim is true, the dimension is collapsed down to one element. + * + *

For example, given a 3D matrix on the axis [x, y, z], if + * {@code matrix.slice(all(), at(0), at(0)}, then the rank of the returned slice is 1 and its + * number of elements is {@code x.numElements()} + * + * @param coord coordinate of the element on the indexed axis + * @param keepDim whether to remove the dimension. + * @return index + */ + public static TensorIndex at(long coord, boolean keepDim) { + return new At(coord, keepDim); + } + + /** + * A coordinate that selects a specific element on a given dimension. + * + *

This is equivalent to call {@link #at(long, boolean)} but where the value of the coordinate is + * provided by an N-dimensional array. + *

+ * If {@code} keepDim is true, the dimension is collapsed down to one element instead of being removed. + * + * @param coord scalar indicating the coordinate of the element on the indexed axis + * @return index + * @param keepDim whether to remove the dimension. + * @throws IllegalRankException if {@code coord} is not a scalar (rank 0) + */ + public static TensorIndex at(NdArray coord, boolean keepDim) { + if (coord.rank() > 0) { + throw new IllegalRankException("Only scalars are accepted as a value index"); + } + return new At(coord.getObject().longValue(), keepDim); } /** @@ -72,7 +111,7 @@ public static Index at(NdArray coord) { * * @return index */ - public static Index all() { + public static TensorIndex all() { return All.INSTANCE; } @@ -216,4 +255,22 @@ public static Index flip() { public static Index hyperslab(long start, long stride, long count, long block) { return new Hyperslab(start, stride, count, block); } + + //TODO comments, tests, remove extra classes in favor of helper methods + + public static TensorIndex newAxis(){ + return NewAxis.INSTANCE; + } + + public static TensorIndex ellipsis(){ + return Ellipsis.INSTANCE; + } + + public static TensorIndex expand(){ + return ellipsis(); + } + + public static TensorIndex slice(Long start, Long end, long stride){ + return new Slice(start, end, stride); + } } diff --git a/ndarray/src/main/java/org/tensorflow/ndarray/index/NewAxis.java b/ndarray/src/main/java/org/tensorflow/ndarray/index/NewAxis.java new file mode 100644 index 00000000000..f29f7058732 --- /dev/null +++ b/ndarray/src/main/java/org/tensorflow/ndarray/index/NewAxis.java @@ -0,0 +1,53 @@ +/* + Copyright 2020 The TensorFlow Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + ============================================================================== + */ +package org.tensorflow.ndarray.index; + +import org.tensorflow.ndarray.impl.dimension.Dimension; + +final class NewAxis implements TensorIndex { + + static final NewAxis INSTANCE = new NewAxis(); + + private NewAxis(){ + + } + + @Override + public long numElements(Dimension dim) { + return 1; + } + + @Override + public long mapCoordinate(long coordinate, Dimension dim) { + return coordinate; + } + + @Override + public Dimension apply(Dimension dim) { + throw new IllegalStateException(); + } + + @Override + public boolean isNewAxis() { + return true; + } + + @Override + public boolean newAxisMask() { + return true; + } +} diff --git a/ndarray/src/main/java/org/tensorflow/ndarray/index/Slice.java b/ndarray/src/main/java/org/tensorflow/ndarray/index/Slice.java new file mode 100644 index 00000000000..05ffa60adeb --- /dev/null +++ b/ndarray/src/main/java/org/tensorflow/ndarray/index/Slice.java @@ -0,0 +1,104 @@ +/* + Copyright 2020 The TensorFlow Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + ============================================================================== + */ +package org.tensorflow.ndarray.index; + +import org.tensorflow.ndarray.impl.dimension.Dimension; + +final class Slice implements TensorIndex { + + private final Long start; + private final Long end; + private final long stride; + + private long start(Dimension dim){ + if(start == null){ + if(stride > 0){ + return 0; + } else { + return dim.numElements() - 1; // it's inclusive + } + } else if(start < 0){ + return dim.numElements() + start; + } else { + return start; + } + } + + private long end(Dimension dim){ + if(end == null){ + if(stride > 0){ + return dim.numElements(); + } else { + return -1; // it's exclusive + } + } else if(end < 0){ + return dim.numElements() + end; + } else { + return end; + } + } + + Slice(Long start, Long end, long stride) { + this.start = start; + this.end = end; + this.stride = stride; + + if(stride != 0){ + throw new IllegalArgumentException("Can not have a stride of 0"); + } + } + + @Override + public long numElements(Dimension dim) { + long length = end(dim) - start(dim); + if(stride < 0){ + length *= -1; + } + + return (length / stride) + (length % stride != 0 ? 1 : 0); + } + + @Override + public long mapCoordinate(long coordinate, Dimension dim) { + return start(dim) + stride * coordinate; + } + + @Override + public long begin() { + return start == null ? 0 : start; + } + + @Override + public long end() { + return end == null ? 0 : end; + } + + @Override + public long stride() { + return stride; + } + + @Override + public boolean beginMask() { + return start == null; + } + + @Override + public boolean endMask() { + return end == null; + } +} diff --git a/ndarray/src/main/java/org/tensorflow/ndarray/index/TensorIndex.java b/ndarray/src/main/java/org/tensorflow/ndarray/index/TensorIndex.java new file mode 100644 index 00000000000..c66dd18de10 --- /dev/null +++ b/ndarray/src/main/java/org/tensorflow/ndarray/index/TensorIndex.java @@ -0,0 +1,50 @@ +/* + Copyright 2020 The TensorFlow Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + ============================================================================== + */ +package org.tensorflow.ndarray.index; + +public interface TensorIndex extends Index{ + default long begin(){ + return 0; + } + default long end(){ + return 0; + } + + default long stride(){ + return 1; + } + + default boolean beginMask(){ + return false; + } + + default boolean endMask(){ + return false; + } + + default boolean ellipsisMask(){ + return false; + } + + default boolean newAxisMask(){ + return false; + } + + default boolean shrinkAxisMask(){ + return false; + } +} From 0e7aa1768cc9257632cce52447869fd33612c0cf Mon Sep 17 00:00:00 2001 From: Ryan Nett Date: Sun, 27 Dec 2020 19:22:41 -0800 Subject: [PATCH 11/24] remove old Index class Signed-off-by: Ryan Nett --- .../org/tensorflow/ndarray/index/Indices.java | 28 ++ .../annotations/org/tensorflow/op/Ops.java | 27 +- .../main/java/org/tensorflow/op/Index.java | 371 ------------------ ...{Indexing.java => StridedSliceHelper.java} | 51 +-- .../test/java/org/tensorflow/IndexTest.java | 14 +- .../org/tensorflow/op/core/IndexingTest.java | 24 +- 6 files changed, 87 insertions(+), 428 deletions(-) delete mode 100644 tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/Index.java rename tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/core/{Indexing.java => StridedSliceHelper.java} (83%) diff --git a/ndarray/src/main/java/org/tensorflow/ndarray/index/Indices.java b/ndarray/src/main/java/org/tensorflow/ndarray/index/Indices.java index 5c449df8a43..27592b02327 100644 --- a/ndarray/src/main/java/org/tensorflow/ndarray/index/Indices.java +++ b/ndarray/src/main/java/org/tensorflow/ndarray/index/Indices.java @@ -258,6 +258,10 @@ public static Index hyperslab(long start, long stride, long count, long block) { //TODO comments, tests, remove extra classes in favor of helper methods + /** + * + * @return + */ public static TensorIndex newAxis(){ return NewAxis.INSTANCE; } @@ -270,7 +274,31 @@ public static TensorIndex expand(){ return ellipsis(); } + public static TensorIndex slice(Long start, Long end){ + return slice(start, end, 1); + } + public static TensorIndex slice(Long start, Long end, long stride){ return new Slice(start, end, stride); } + + public static TensorIndex slice(Integer start, int end){ + return intSlice(start, end, 1); + } + + public static TensorIndex slice(int start, Integer end){ + return intSlice(start, end, 1); + } + + public static TensorIndex slice(Integer start, int end, long stride){ + return intSlice(start, end, stride); + } + + public static TensorIndex slice(int start, Integer end, long stride){ + return intSlice(start, end, stride); + } + + private static TensorIndex intSlice(Integer start, Integer end, long stride){ + return new Slice(start == null ? null : start.longValue(), end == null ? null : end.longValue(), stride); + } } diff --git a/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/Ops.java b/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/Ops.java index ecc19accea3..0a95dbce72a 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/Ops.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/Ops.java @@ -38,6 +38,7 @@ import org.tensorflow.ndarray.buffer.FloatDataBuffer; import org.tensorflow.ndarray.buffer.IntDataBuffer; import org.tensorflow.ndarray.buffer.LongDataBuffer; +import org.tensorflow.ndarray.index.TensorIndex; import org.tensorflow.op.core.Abort; import org.tensorflow.op.core.All; import org.tensorflow.op.core.Any; @@ -93,7 +94,6 @@ import org.tensorflow.op.core.Identity; import org.tensorflow.op.core.IdentityN; import org.tensorflow.op.core.ImmutableConst; -import org.tensorflow.op.core.Indexing; import org.tensorflow.op.core.Init; import org.tensorflow.op.core.InitializeTable; import org.tensorflow.op.core.InitializeTableFromTextFile; @@ -211,6 +211,7 @@ import org.tensorflow.op.core.StridedSlice; import org.tensorflow.op.core.StridedSliceAssign; import org.tensorflow.op.core.StridedSliceGrad; +import org.tensorflow.op.core.StridedSliceHelper; import org.tensorflow.op.core.Sum; import org.tensorflow.op.core.SwitchCond; import org.tensorflow.op.core.TemporaryVariable; @@ -5911,15 +5912,15 @@ public StopGradient stopGradient(Operand input) { * `m` could be equal to `n`, but this need not be the case. Each * range specification entry can be one of the following: *

- * - An ellipsis (...) using {@link Index#ellipses()}. Ellipses are used to imply zero or more + * - An ellipsis (...) using {@link Indices#ellipsis()}. Ellipses are used to imply zero or more * dimensions of full-dimension selection and are produced using * `ellipsis_mask`. For example, `foo[...]` is the identity slice. *

- * - A new axis using {@link Index#newAxis()}. This is used to insert a new shape=1 dimension and is + * - A new axis using {@link Indices#newAxis()}. This is used to insert a new shape=1 dimension and is * produced using `new_axis_mask`. For example, `foo[:, ...]` where * `foo` is shape `(3, 4)` produces a `(1, 3, 4)` tensor. *

- * - A range `begin:end:stride` using {@link Index#slice(Singular, Singular, int) Index.slice()} or {@link Index#all()}. This is used to specify how much to choose from + * - A range `begin:end:stride` using {@link Indices#slice(Long, Long, long)} Index.slice()}. This is used to specify how much to choose from * a given dimension. `stride` can be any integer but 0. `begin` is an integer * which represents the index of the first value to select while `end` represents * the index of the last value to select. The number of values selected in each @@ -5935,7 +5936,7 @@ public StopGradient stopGradient(Operand input) { * first dimension of a tensor while dropping the last two (in the original * order elements). For example `foo = [1,2,3,4]; foo[-2::-1]` is `[4,3]`. *

- * - A single index using {@link Index#point(int)}. This is used to keep only elements that have a given + * - A single index using {@link Indices#at(long)}. This is used to keep only elements that have a given * index. For example (`foo[2, :]` on a shape `(5,6)` tensor produces a * shape `(6,)` tensor. This is encoded in `begin` and `end` and * `shrink_axis_mask`. @@ -5948,12 +5949,12 @@ public StopGradient stopGradient(Operand input) { * @param scope current scope * @param data type for {@code output()} output * @param input - * @param indices The indices to slice. See {@link Index}. + * @param indices The indices to slice. See {@link Indices}. * @return a new instance of StridedSlice - * @see Index + * @see Indices */ - public StridedSlice stridedSlice(Operand input, Index... indices) { - return Indexing.stridedSlice(scope, input, indices); + public StridedSlice stridedSlice(Operand input, TensorIndex... indices) { + return StridedSliceHelper.stridedSlice(scope, input, indices); } /** @@ -6082,13 +6083,13 @@ public StridedSlice stridedSlice(Operand * @param scope current scope * @param ref the tensor to assign to. * @param value the value to assign. - * @param indices The indices to slice. See {@link Index}. + * @param indices The indices to slice. See {@link Indices}. * @return a new instance of StridedSliceAssign - * @see org.tensorflow.op.Ops#stridedSlice(Operand, Index...) + * @see org.tensorflow.op.Ops#stridedSlice(Operand, TensorIndex...) */ public StridedSliceAssign stridedSliceAssign(Operand ref, - Operand value, Index... indices) { - return Indexing.stridedSliceAssign(scope, ref, value, indices); + Operand value, TensorIndex... indices) { + return StridedSliceHelper.stridedSliceAssign(scope, ref, value, indices); } /** diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/Index.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/Index.java deleted file mode 100644 index 5655ced90c5..00000000000 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/Index.java +++ /dev/null @@ -1,371 +0,0 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ -package org.tensorflow.op; - -import org.tensorflow.op.annotation.Endpoint; - -/** - * Numpy-like indexing. Supports slices, stride slices, open slices, all, newaxis, and ellipsis. - *

- * Examples: - *

{@code
- * x[1:-1, :, tf.newaxis, ...]
- * // becomes
- * stridedSlice(x, Index.slice(1, -1), Index.all(), Index.newAxis(), Index.ellipsis())
- *
- *
- * x[2, 10:, :-10, 2:-2:2]
- * // becomes
- * stridedSlice(x, Index.point(2), Index.slice(10, null), Index.slice(null, -10), Index.slice(2, -2, 2))
- * }
- * - */ -public abstract class Index { - - private final int begin; - private final int end; - private final int stride; - private final boolean beginMask; - private final boolean endMask; - private final boolean ellipsisMask; - private final boolean newAxisMask; - private final boolean shrinkAxisMask; - - private Index(int begin, int end, int stride, boolean beginMask, boolean endMask, boolean ellipsisMask, - boolean newAxisMask, boolean shrinkAxisMask) { - this.begin = begin; - this.end = end; - this.stride = stride; - this.beginMask = beginMask; - this.endMask = endMask; - this.ellipsisMask = ellipsisMask; - this.newAxisMask = newAxisMask; - this.shrinkAxisMask = shrinkAxisMask; - } - - /** - * @return the beginning index of the slice. - */ - public int getBegin() { - return begin; - } - - /** - * @return the end (exclusive) index of the slice. - */ - public int getEnd() { - return end; - } - - /** - * @return the stride of the slice. - */ - public int getStride() { - return stride; - } - - /** - * @return whether to begin at the beginning. - */ - public boolean isBeginMask() { - return beginMask; - } - - /** - * @return whether to end at the end. - */ - public boolean isEndMask() { - return endMask; - } - - /** - * @return is this index an {@link Ellipses} - */ - public boolean isEllipsisMask() { - return ellipsisMask; - } - - /** - * @return should this index add a new dimension. - */ - public boolean isNewAxisMask() { - return newAxisMask; - } - - /** - * @return should this index shrink its dimension. - */ - public boolean isShrinkAxisMask() { - return shrinkAxisMask; - } - - /** - * An index that can be used as the start or end of a slice. - */ - public static abstract class Singular extends Index { - - private Singular(int begin, int end, int stride, boolean beginMask, boolean endMask, boolean ellipsisMask, - boolean newAxisMask, boolean shrinkAxisMask) { - super(begin, end, stride, beginMask, endMask, ellipsisMask, newAxisMask, shrinkAxisMask); - } - } - - /** - * An index that gets the entire dimension. - */ - public static class All extends Singular { - - private All() { - super(0, 0, 1, true, true, false, false, false); - } - } - - /** - * An index that gets a single point from its dimension, collapsing it by default. - */ - public static class Point extends Singular { - - private final int index; - - private Point(int index, boolean keepDim) { - super(index, index + 1, 1, false, false, false, false, !keepDim); - this.index = index; - } - - public int getIndex() { - return index; - } - } - - /** - * An index that adds a new dimension of size 1 where it is used. - */ - public static class NewAxis extends Index { - - private NewAxis() { - super(0, 0, 1, false, false, false, true, false); - } - } - - /** - * An index that expands to get all possible dimensions. - */ - public static class Ellipses extends Index { - - private Ellipses() { - super(0, 0, 1, false, false, true, false, false); - } - } - - /** - * An index to get a slice of its dimension, with an optional stride. - */ - public static class Slice extends Index { - - private Slice(Singular start, Singular end, int stride) { - super(start instanceof Point ? ((Point) start).index : 0, - end instanceof Point ? ((Point) end).index : 0, - stride, - start instanceof All, - end instanceof All, - false, - false, - false); - - if(stride != 0){ - throw new IllegalArgumentException("Can not have a stride of 0"); - } - } - } - - /** - * An index that gets the entire dimension. - *

- * Equivalent to Python's {@code :}. - */ - public static All all(){ - return new All(); - } - - /** - * An index that gets the value at the given position, and collapses the dimension (removing it). - *

- * Equivalent to Python's indexing, and supports negative values in the same way. - * - * @param index The position to get. - */ - public static Point point(int index){ - return new Point(index, false); - } - - - /** - * An index that gets the value at the given position, and collapses the dimension (removing it) if keepDim is false. - * - * @param index The position to get. - * @param keepDim Whether to keep the dimension as size 1. - */ - public static Point point(int index, boolean keepDim){ - return new Point(index, keepDim); - } - - /** - * An index that adds a new dimension of size 1 where it is used. - *

- * Equivalent to Python's {@code np.newaxis}, {@code tf.newaxis} or {@code None}. - */ - public static NewAxis newAxis(){ - return new NewAxis(); - } - - /** - * An index that expands to get all possible dimensions. - *

- * Equivalent to Python's {@code ...}. - */ - public static Ellipses ellipses(){ - return new Ellipses(); - } - - /** - * An index to get a slice of its dimension. - * Start and end can be null or All to start or end at the beginning or end, respectively. - *

- * Equivalent to Python's {@code :} slicing syntax: - *

{@code
-   * :
-   * // becomes
-   * Index.all()
-   * Index.slice(null, null)
-   * Index.slice(Index.all(), Index.all())
-   *
-   * 2:
-   * // becomes
-   * Index.slice(2, null)
-   *
-   * :2
-   * // becomes
-   * Index.slice(null, 2)
-   *
-   * 2:10
-   * // becomes
-   * Index.slice(2, 10)
-   *
-   * :2
-   * // becomes
-   * Index.slice(null, null, 2)
-   *
-   * 2:10:2
-   * //becomes
-   * Index.slice(2, 10, 2)
-   * }
- * - * @param start Where to start the slice. Starts at the beginning if null or All. - * @param end Where to end the slice (exclusive). Ends at the end if null or All. - * @param stride The stride. - */ - public static Slice slice(Singular start, Singular end, int stride){ - return new Slice(start == null ? all() : start, end == null ? all() : end, stride); - } - - - /** - * An index to get a slice of its dimension. - * End can be null or All to end at the end. - * - * @param start Where to start the slice. - * @param end Where to end the slice (exclusive). Ends at the end if null or All. - * @param stride The stride. - * @see #slice(Singular, Singular, int) - */ - public static Slice slice(int start, Singular end, int stride){ - return slice(point(start), end, stride); - } - - - /** - * An index to get a slice of its dimension. - * Start can be null or All to start at the beginning. - * - * @param start Where to start the slice. Starts at the beginning if null or All. - * @param end Where to end the slice (exclusive). - * @param stride The stride. - * @see #slice(Singular, Singular, int) - */ - public static Slice slice(Singular start, int end, int stride){ - return slice(start, point(end), stride); - } - - - /** - * An index to get a slice of its dimension. - * - * @param start Where to start the slice. - * @param end Where to end the slice (exclusive). - * @param stride The stride. - * @see #slice(Singular, Singular, int) - */ - public static Slice slice(int start, int end, int stride){ - return slice(point(start), point(end), stride); - } - - /** - * An index to get a slice of its dimension. - * Start and end can be null or All to start or end at the beginning or end, respectively. - * - * @param start Where to start the slice. Starts at the beginning if null or All. - * @param end Where to end the slice (exclusive). Ends at the end if null or All. - * @see #slice(Singular, Singular, int) - */ - public static Slice slice(Singular start, Singular end){ - return slice(start, end, 1); - } - - /** - * An index to get a slice of its dimension. - * End can be null or All to end at the end. - * - * @param start Where to start the slice. - * @param end Where to end the slice (exclusive). Ends at the end if null or All. - * @see #slice(Singular, Singular, int) - */ - public static Slice slice(int start, Singular end){ - return slice(start, end, 1); - } - - /** - * An index to get a slice of its dimension. - * Start can be null or All to start at the beginning. - * - * @param start Where to start the slice. Starts at the beginning if null or All. - * @param end Where to end the slice (exclusive). - * @see #slice(Singular, Singular, int) - */ - public static Slice slice(Singular start, int end){ - return slice(start, end, 1); - } - - /** - * An index to get a slice of its dimension. - * - * @param start Where to start the slice. - * @param end Where to end the slice (exclusive). - * @see #slice(Singular, Singular, int) - */ - public static Slice slice(int start, int end){ - return slice(start, end, 1); - } - -} diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/core/Indexing.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/core/StridedSliceHelper.java similarity index 83% rename from tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/core/Indexing.java rename to tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/core/StridedSliceHelper.java index a82e544ee67..e2643b41d2c 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/core/Indexing.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/core/StridedSliceHelper.java @@ -15,8 +15,8 @@ package org.tensorflow.op.core; import org.tensorflow.Operand; -import org.tensorflow.op.Index; -import org.tensorflow.op.Index.Singular; +import org.tensorflow.ndarray.index.Indices; +import org.tensorflow.ndarray.index.TensorIndex; import org.tensorflow.op.Scope; import org.tensorflow.op.annotation.Endpoint; import org.tensorflow.op.annotation.Operator; @@ -25,10 +25,10 @@ /** * Helper endpoint methods for Python like indexing. * - * @see Index + * @see org.tensorflow.ndarray.index.Indices */ @Operator -public class Indexing { +public class StridedSliceHelper { static class StridedSliceArgs { @@ -54,7 +54,7 @@ private StridedSliceArgs(int[] begin, int[] end, int[] strides, long beginMask, } } - static StridedSliceArgs mergeIndexes(Index[] indices) { + static StridedSliceArgs mergeIndexes(TensorIndex[] indices) { int[] begin = new int[indices.length]; int[] end = new int[indices.length]; int[] strides = new int[indices.length]; @@ -65,34 +65,35 @@ static StridedSliceArgs mergeIndexes(Index[] indices) { long shrinkAxisMask = 0; for (int i = 0; i < indices.length; i++) { - Index idx = indices[i]; + TensorIndex idx = indices[i]; if (idx == null) { - idx = Index.all(); + idx = Indices.all(); } - begin[i] = idx.getBegin(); - end[i] = idx.getEnd(); - strides[i] = idx.getStride(); + //TODO warnings for out of bounds? + begin[i] = (int) idx.begin(); + end[i] = (int) idx.end(); + strides[i] = (int) idx.stride(); - if (idx.isBeginMask()) { + if (idx.beginMask()) { beginMask |= 1L << i; } - if (idx.isEndMask()) { + if (idx.endMask()) { endMask |= 1L << i; } - if (idx.isEllipsisMask()) { + if (idx.ellipsisMask()) { if(ellipsisMask != 0) throw new IllegalArgumentException("Can not have two ellipsis in a slice"); ellipsisMask |= 1L << i; } - if (idx.isNewAxisMask()) { + if (idx.newAxisMask()) { newAxisMask |= 1L << i; } - if (idx.isShrinkAxisMask()) { + if (idx.shrinkAxisMask()) { shrinkAxisMask |= 1L << i; } } @@ -109,15 +110,15 @@ static StridedSliceArgs mergeIndexes(Index[] indices) { * `m` could be equal to `n`, but this need not be the case. Each * range specification entry can be one of the following: *

- * - An ellipsis (...) using {@link Index#ellipses()}. Ellipses are used to imply zero or more + * - An ellipsis (...) using {@link Indices#ellipsis()}. Ellipses are used to imply zero or more * dimensions of full-dimension selection and are produced using * `ellipsis_mask`. For example, `foo[...]` is the identity slice. *

- * - A new axis using {@link Index#newAxis()}. This is used to insert a new shape=1 dimension and is + * - A new axis using {@link Indices#newAxis()}. This is used to insert a new shape=1 dimension and is * produced using `new_axis_mask`. For example, `foo[:, ...]` where * `foo` is shape `(3, 4)` produces a `(1, 3, 4)` tensor. *

- * - A range `begin:end:stride` using {@link Index#slice(Singular, Singular, int) Index.slice()} or {@link Index#all()}. This is used to specify how much to choose from + * - A range `begin:end:stride` using {@link Indices#slice(Long, Long, long)} Index.slice()}. This is used to specify how much to choose from * a given dimension. `stride` can be any integer but 0. `begin` is an integer * which represents the index of the first value to select while `end` represents * the index of the last value to select. The number of values selected in each @@ -133,7 +134,7 @@ static StridedSliceArgs mergeIndexes(Index[] indices) { * first dimension of a tensor while dropping the last two (in the original * order elements). For example `foo = [1,2,3,4]; foo[-2::-1]` is `[4,3]`. *

- * - A single index using {@link Index#point(int)}. This is used to keep only elements that have a given + * - A single index using {@link Indices#at(long)}. This is used to keep only elements that have a given * index. For example (`foo[2, :]` on a shape `(5,6)` tensor produces a * shape `(6,)` tensor. This is encoded in `begin` and `end` and * `shrink_axis_mask`. @@ -146,12 +147,12 @@ static StridedSliceArgs mergeIndexes(Index[] indices) { * @param scope current scope * @param data type for {@code output()} output * @param input - * @param indices The indices to slice. See {@link Index}. + * @param indices The indices to slice. See {@link Indices}. * @return a new instance of StridedSlice - * @see Index + * @see Indices */ @Endpoint(name = "stridedSlice") - public static StridedSlice stridedSlice(Scope scope, Operand input, Index... indices) { + public static StridedSlice stridedSlice(Scope scope, Operand input, TensorIndex... indices) { StridedSliceArgs args = mergeIndexes(indices); return StridedSlice.create( scope, @@ -181,12 +182,12 @@ public static StridedSlice stridedSlice(Scope scope, Operan * @param scope current scope * @param ref the tensor to assign to. * @param value the value to assign. - * @param indices The indices to slice. See {@link Index}. + * @param indices The indices to slice. See {@link Indices}. * @return a new instance of StridedSliceAssign - * @see org.tensorflow.op.Ops#stridedSlice(Operand, Index...) + * @see org.tensorflow.op.Ops#stridedSlice(Operand, TensorIndex...) */ @Endpoint(name = "stridedSliceAssign") - public static StridedSliceAssign stridedSliceAssign(Scope scope, Operand ref, Operand value, Index... indices) { + public static StridedSliceAssign stridedSliceAssign(Scope scope, Operand ref, Operand value, TensorIndex... indices) { StridedSliceArgs args = mergeIndexes(indices); return StridedSliceAssign.create( scope, diff --git a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/IndexTest.java b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/IndexTest.java index 18a71ab9f98..ec37418baaf 100644 --- a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/IndexTest.java +++ b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/IndexTest.java @@ -17,27 +17,27 @@ import static org.junit.jupiter.api.Assertions.assertTrue; import org.junit.Test; -import org.tensorflow.op.Index; +import org.tensorflow.ndarray.index.Indices; public class IndexTest { @Test public void testNullConversions(){ - assertTrue(Index.slice(null, 0).isBeginMask(), + assertTrue(Indices.slice(null, 0).beginMask(), "Passed null for slice start but didn't set begin mask"); - assertTrue(Index.slice(null, Index.point(0)).isBeginMask(), + assertTrue(Indices.slice(null, 0).beginMask(), "Passed null for slice start but didn't set begin mask"); - assertTrue(Index.slice(null, null).isBeginMask(), + assertTrue(Indices.slice(null, null).beginMask(), "Passed null for slice start but didn't set begin mask"); - assertTrue(Index.slice(0, null).isEndMask(), + assertTrue(Indices.slice(0, null).endMask(), "Passed null for slice end but didn't set end mask"); - assertTrue(Index.slice(Index.point(0), null).isEndMask(), + assertTrue(Indices.slice(0, null).endMask(), "Passed null for slice end but didn't set end mask"); - assertTrue(Index.slice(null, null).isEndMask(), + assertTrue(Indices.slice(null, null).endMask(), "Passed null for slice end but didn't set end mask"); } } diff --git a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/op/core/IndexingTest.java b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/op/core/IndexingTest.java index 2f5453017e7..b32d33f7948 100644 --- a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/op/core/IndexingTest.java +++ b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/op/core/IndexingTest.java @@ -16,33 +16,33 @@ import static org.junit.jupiter.api.Assertions.assertArrayEquals; import static org.junit.jupiter.api.Assertions.assertEquals; -import static org.tensorflow.ndarray.Shape.of; import org.junit.Test; import org.tensorflow.Graph; import org.tensorflow.Session; import org.tensorflow.Tensor; import org.tensorflow.ndarray.Shape; -import org.tensorflow.op.Index; +import org.tensorflow.ndarray.index.Indices; +import org.tensorflow.ndarray.index.TensorIndex; import org.tensorflow.op.Scope; import org.tensorflow.types.TFloat32; public class IndexingTest { // [2, 1:2, :, tf.newaxis, ..., :4, 4::2] - private static final Index[] slice = new Index[]{ - Index.point(2), - Index.point(1, true), - Index.all(), - Index.newAxis(), - Index.ellipses(), - Index.slice(Index.all(), 4), - Index.slice(4, null, 2) + private static final TensorIndex[] slice = new TensorIndex[]{ + Indices.at(2), + Indices.at(1, true), + Indices.all(), + Indices.newAxis(), + Indices.ellipsis(), + Indices.slice(null, 4), + Indices.slice(4, null, 2) }; @Test public void testIndexMerge() { - Indexing.StridedSliceArgs args = Indexing.mergeIndexes(slice); + StridedSliceHelper.StridedSliceArgs args = StridedSliceHelper.mergeIndexes(slice); assertArrayEquals(new int[]{2, 1, 0, 0, 0, 0, 4}, args.begin); assertArrayEquals(new int[]{3, 2, 0, 0, 0, 4, 0}, args.end); @@ -62,7 +62,7 @@ public void testStridedSliceIndex(){ Scope scope = new Scope(g); long[] shape = {10, 10, 10, 10, 10, 10, 10, 10}; Zeros op = Zeros.create(scope, Constant.vectorOf(scope, shape), TFloat32.DTYPE); - StridedSlice output = Indexing.stridedSlice(scope, op, slice); + StridedSlice output = StridedSliceHelper.stridedSlice(scope, op, slice); try (Tensor result = sess.runner().fetch(output.asOutput()).run().get(0).expect(TFloat32.DTYPE)) { // expected shape from Python tensorflow assertEquals(Shape.of(1, 10, 1, 10, 10, 10, 4, 3), result.data().shape(), "Slice index didn't match expected (Python)"); From 6383df43874b029b77673d8a36707150e130e7e9 Mon Sep 17 00:00:00 2001 From: Ryan Nett Date: Sun, 27 Dec 2020 19:28:59 -0800 Subject: [PATCH 12/24] test fix for rebase Signed-off-by: Ryan Nett --- .../src/test/java/org/tensorflow/op/core/IndexingTest.java | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/op/core/IndexingTest.java b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/op/core/IndexingTest.java index b32d33f7948..259a8d66d24 100644 --- a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/op/core/IndexingTest.java +++ b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/op/core/IndexingTest.java @@ -61,11 +61,11 @@ public void testStridedSliceIndex(){ Session sess = new Session(g)) { Scope scope = new Scope(g); long[] shape = {10, 10, 10, 10, 10, 10, 10, 10}; - Zeros op = Zeros.create(scope, Constant.vectorOf(scope, shape), TFloat32.DTYPE); + Zeros op = Zeros.create(scope, Constant.vectorOf(scope, shape), TFloat32.class); StridedSlice output = StridedSliceHelper.stridedSlice(scope, op, slice); - try (Tensor result = sess.runner().fetch(output.asOutput()).run().get(0).expect(TFloat32.DTYPE)) { + try (TFloat32 result = (TFloat32) sess.runner().fetch(output.asOutput()).run().get(0)) { // expected shape from Python tensorflow - assertEquals(Shape.of(1, 10, 1, 10, 10, 10, 4, 3), result.data().shape(), "Slice index didn't match expected (Python)"); + assertEquals(Shape.of(1, 10, 1, 10, 10, 10, 4, 3), result.shape(), "Slice index didn't match expected (Python)"); } } } From 9b1da7713b0e5cf7740f0b755c6c9b41a3eec8fd Mon Sep 17 00:00:00 2001 From: Ryan Nett Date: Sun, 27 Dec 2020 19:31:26 -0800 Subject: [PATCH 13/24] test fix for rebase Signed-off-by: Ryan Nett --- .../org/tensorflow/ndarray/index/Slice.java | 2 +- .../org/tensorflow/ndarray}/IndexTest.java | 28 ++++++++++--------- 2 files changed, 16 insertions(+), 14 deletions(-) rename {tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow => ndarray/src/test/java/org/tensorflow/ndarray}/IndexTest.java (54%) diff --git a/ndarray/src/main/java/org/tensorflow/ndarray/index/Slice.java b/ndarray/src/main/java/org/tensorflow/ndarray/index/Slice.java index 05ffa60adeb..a1082ebb887 100644 --- a/ndarray/src/main/java/org/tensorflow/ndarray/index/Slice.java +++ b/ndarray/src/main/java/org/tensorflow/ndarray/index/Slice.java @@ -57,7 +57,7 @@ private long end(Dimension dim){ this.end = end; this.stride = stride; - if(stride != 0){ + if(stride == 0){ throw new IllegalArgumentException("Can not have a stride of 0"); } } diff --git a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/IndexTest.java b/ndarray/src/test/java/org/tensorflow/ndarray/IndexTest.java similarity index 54% rename from tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/IndexTest.java rename to ndarray/src/test/java/org/tensorflow/ndarray/IndexTest.java index ec37418baaf..a0590b25cee 100644 --- a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/IndexTest.java +++ b/ndarray/src/test/java/org/tensorflow/ndarray/IndexTest.java @@ -1,22 +1,24 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. +/* + Copyright 2020 The TensorFlow Authors. All Rights Reserved. -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at - http://www.apache.org/licenses/LICENSE-2.0 + http://www.apache.org/licenses/LICENSE-2.0 -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ -package org.tensorflow; + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + ============================================================================== + */ +package org.tensorflow.ndarray; import static org.junit.jupiter.api.Assertions.assertTrue; -import org.junit.Test; +import org.junit.jupiter.api.Test; import org.tensorflow.ndarray.index.Indices; public class IndexTest { From f2c49661e70122fdd2b90741575844d3743905b4 Mon Sep 17 00:00:00 2001 From: Ryan Nett Date: Mon, 4 Jan 2021 21:05:01 -0800 Subject: [PATCH 14/24] Out of bounds warnings Signed-off-by: Ryan Nett --- .../org/tensorflow/op/core/StridedSliceHelper.java | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/core/StridedSliceHelper.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/core/StridedSliceHelper.java index e2643b41d2c..37111979d3d 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/core/StridedSliceHelper.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/core/StridedSliceHelper.java @@ -28,7 +28,7 @@ * @see org.tensorflow.ndarray.index.Indices */ @Operator -public class StridedSliceHelper { +public abstract class StridedSliceHelper { static class StridedSliceArgs { @@ -70,10 +70,17 @@ static StridedSliceArgs mergeIndexes(TensorIndex[] indices) { idx = Indices.all(); } - //TODO warnings for out of bounds? begin[i] = (int) idx.begin(); + if(begin[i] != idx.begin()) + throw new IllegalArgumentException("Can't convert long begin value to int for index " + idx + ": Out of bounds"); + end[i] = (int) idx.end(); + if(end[i] != idx.end()) + throw new IllegalArgumentException("Can't convert long end value to int for index " + idx + ": Out of bounds"); + strides[i] = (int) idx.stride(); + if(strides[i] != idx.stride()) + throw new IllegalArgumentException("Can't convert long stride value to int for index " + idx + ": Out of bounds"); if (idx.beginMask()) { beginMask |= 1L << i; From 35b7aaaac789948465f1b63abd63544ae1523efd Mon Sep 17 00:00:00 2001 From: Ryan Nett Date: Mon, 4 Jan 2021 21:05:27 -0800 Subject: [PATCH 15/24] Remove extra classes, plus a few fixes Signed-off-by: Ryan Nett --- .../impl/dimension/DimensionalSpace.java | 4 +- .../java/org/tensorflow/ndarray/index/At.java | 2 +- .../org/tensorflow/ndarray/index/Even.java | 37 ----- .../org/tensorflow/ndarray/index/Flip.java | 34 ----- .../org/tensorflow/ndarray/index/From.java | 38 ----- .../org/tensorflow/ndarray/index/Indices.java | 137 ++++++++++++++---- .../org/tensorflow/ndarray/index/Odd.java | 37 ----- .../org/tensorflow/ndarray/index/Range.java | 40 ----- .../org/tensorflow/ndarray/index/Slice.java | 3 - .../org/tensorflow/ndarray/index/Step.java | 38 ----- .../java/org/tensorflow/ndarray/index/To.java | 38 ----- 11 files changed, 111 insertions(+), 297 deletions(-) delete mode 100644 ndarray/src/main/java/org/tensorflow/ndarray/index/Even.java delete mode 100644 ndarray/src/main/java/org/tensorflow/ndarray/index/Flip.java delete mode 100644 ndarray/src/main/java/org/tensorflow/ndarray/index/From.java delete mode 100644 ndarray/src/main/java/org/tensorflow/ndarray/index/Odd.java delete mode 100644 ndarray/src/main/java/org/tensorflow/ndarray/index/Range.java delete mode 100644 ndarray/src/main/java/org/tensorflow/ndarray/index/Step.java delete mode 100644 ndarray/src/main/java/org/tensorflow/ndarray/index/To.java diff --git a/ndarray/src/main/java/org/tensorflow/ndarray/impl/dimension/DimensionalSpace.java b/ndarray/src/main/java/org/tensorflow/ndarray/impl/dimension/DimensionalSpace.java index 1e4303e04f9..55440b7d719 100644 --- a/ndarray/src/main/java/org/tensorflow/ndarray/impl/dimension/DimensionalSpace.java +++ b/ndarray/src/main/java/org/tensorflow/ndarray/impl/dimension/DimensionalSpace.java @@ -35,7 +35,7 @@ public static DimensionalSpace create(Shape shape) { } public RelativeDimensionalSpace mapTo(Index[] indices) { - if (dimensions == null || indices.length > dimensions.length) { + if (dimensions == null) { throw new ArrayIndexOutOfBoundsException(); } int dimIdx = 0; @@ -63,7 +63,7 @@ public RelativeDimensionalSpace mapTo(Index[] indices) { Dimension[] newDimensions = new Dimension[newLength]; while (indexIdx < indices.length) { - if (indices[dimIdx].isPoint()) { + if (indices[indexIdx].isPoint()) { // When an index targets a single point in a given dimension, calculate the offset of this // point and cumulate the offset of any subsequent point as well long offset = 0; diff --git a/ndarray/src/main/java/org/tensorflow/ndarray/index/At.java b/ndarray/src/main/java/org/tensorflow/ndarray/index/At.java index 7f28f209c00..3fbdab82095 100644 --- a/ndarray/src/main/java/org/tensorflow/ndarray/index/At.java +++ b/ndarray/src/main/java/org/tensorflow/ndarray/index/At.java @@ -27,7 +27,7 @@ public long numElements(Dimension dim) { @Override public long mapCoordinate(long coordinate, Dimension dim) { - long coord = this.coord > 0 ? this.coord : dim.numElements() - this.coord; + long coord = this.coord >= 0 ? this.coord : dim.numElements() + this.coord; return dim.positionOf(coord); // TODO validate coordinate is 0? } diff --git a/ndarray/src/main/java/org/tensorflow/ndarray/index/Even.java b/ndarray/src/main/java/org/tensorflow/ndarray/index/Even.java deleted file mode 100644 index 54f53853c32..00000000000 --- a/ndarray/src/main/java/org/tensorflow/ndarray/index/Even.java +++ /dev/null @@ -1,37 +0,0 @@ -/* - Copyright 2019 The TensorFlow Authors. All Rights Reserved. - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - ======================================================================= - */ -package org.tensorflow.ndarray.index; - -import org.tensorflow.ndarray.impl.dimension.Dimension; - -final class Even implements Index { - - static final Even INSTANCE = new Even(); - - @Override - public long numElements(Dimension dim) { - return (dim.numElements() >> 1) + (dim.numElements() % 2); - } - - @Override - public long mapCoordinate(long coordinate, Dimension dim) { - return coordinate << 1; - } - - private Even() { - } -} diff --git a/ndarray/src/main/java/org/tensorflow/ndarray/index/Flip.java b/ndarray/src/main/java/org/tensorflow/ndarray/index/Flip.java deleted file mode 100644 index 7914d8faad5..00000000000 --- a/ndarray/src/main/java/org/tensorflow/ndarray/index/Flip.java +++ /dev/null @@ -1,34 +0,0 @@ -/* - Copyright 2019 The TensorFlow Authors. All Rights Reserved. - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - ======================================================================= - */ -package org.tensorflow.ndarray.index; - -import org.tensorflow.ndarray.impl.dimension.Dimension; - -final class Flip implements Index { - - static final Flip INSTANCE = new Flip(); - - @Override - public long numElements(Dimension dim) { - return dim.numElements(); - } - - @Override - public long mapCoordinate(long coordinate, Dimension dim) { - return dim.numElements() - coordinate - 1; - } -} diff --git a/ndarray/src/main/java/org/tensorflow/ndarray/index/From.java b/ndarray/src/main/java/org/tensorflow/ndarray/index/From.java deleted file mode 100644 index c541e8370b2..00000000000 --- a/ndarray/src/main/java/org/tensorflow/ndarray/index/From.java +++ /dev/null @@ -1,38 +0,0 @@ -/* - Copyright 2019 The TensorFlow Authors. All Rights Reserved. - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - ======================================================================= - */ -package org.tensorflow.ndarray.index; - -import org.tensorflow.ndarray.impl.dimension.Dimension; - -final class From implements Index { - - @Override - public long numElements(Dimension dim) { - return dim.numElements() - start; - } - - @Override - public long mapCoordinate(long coordinate, Dimension dim) { - return start + coordinate; - } - - From(long start) { - this.start = start; - } - - private final long start; -} diff --git a/ndarray/src/main/java/org/tensorflow/ndarray/index/Indices.java b/ndarray/src/main/java/org/tensorflow/ndarray/index/Indices.java index 27592b02327..42018dc9ca3 100644 --- a/ndarray/src/main/java/org/tensorflow/ndarray/index/Indices.java +++ b/ndarray/src/main/java/org/tensorflow/ndarray/index/Indices.java @@ -157,8 +157,8 @@ public static Index seq(NdArray coords) { * * @return index */ - public static Index even() { - return Even.INSTANCE; + public static TensorIndex even() { + return slice(null, null, 2); } /** @@ -170,8 +170,8 @@ public static Index even() { * * @return index */ - public static Index odd() { - return Odd.INSTANCE; + public static TensorIndex odd() { + return slice(1, null, 2); } /** @@ -183,8 +183,8 @@ public static Index odd() { * @param stepLength the number of elements between each steps * @return index */ - public static Index step(long stepLength) { - return new Step(stepLength); + public static TensorIndex step(long stepLength) { + return slice(null, null, stepLength); } /** @@ -197,8 +197,8 @@ public static Index step(long stepLength) { * @param start coordinate of the first element of the sequence * @return index */ - public static Index from(long start) { - return new From(start); + public static TensorIndex from(long start) { + return slice(start, null); } /** @@ -211,8 +211,8 @@ public static Index from(long start) { * @param end coordinate of the last element of the sequence (exclusive) * @return index */ - public static Index to(long end) { - return new To(end); + public static TensorIndex to(long end) { + return slice(null, end); } /** @@ -225,8 +225,8 @@ public static Index to(long end) { * @param end coordinate of the last element of the sequence (exclusive) * @return index */ - public static Index range(long start, long end) { - return new Range(start, end); + public static TensorIndex range(long start, long end) { + return slice(start, end); } /** @@ -237,8 +237,8 @@ public static Index range(long start, long end) { * * @return index */ - public static Index flip() { - return Flip.INSTANCE; + public static TensorIndex flip() { + return slice(null, null, -1); } /** @@ -256,49 +256,128 @@ public static Index hyperslab(long start, long stride, long count, long block) { return new Hyperslab(start, stride, count, block); } - //TODO comments, tests, remove extra classes in favor of helper methods - /** + * An index that inserts a new dimension of size 1 into the resulting array. * - * @return + * @return index */ public static TensorIndex newAxis(){ return NewAxis.INSTANCE; } + /** + * An index that expands to fill all available source dimensions. + * Works the same as Python's {@code ...}. + * @see #expand() + * @return index + */ public static TensorIndex ellipsis(){ return Ellipsis.INSTANCE; } + /** + * An index that expands to fill all available source dimensions. + * Works the same as Python's {@code ...}. + * + * @return index + */ public static TensorIndex expand(){ return ellipsis(); } + /** + * An index that returns elements between {@code start} and {@code end}. + * If {@code start} or {@code end} is {@code null}, starts or ends at the beginning or the end, respectively. + *

+ * Analogous to Python's {@code :} slice syntax. + * + * @return index + */ public static TensorIndex slice(Long start, Long end){ return slice(start, end, 1); } - public static TensorIndex slice(Long start, Long end, long stride){ - return new Slice(start, end, stride); + /** + * An index that returns elements between {@code start} and {@code end}. + * If {@code start} or {@code end} is {@code null}, starts or ends at the beginning or the end, respectively. + *

+ * Analogous to Python's {@code :} slice syntax. + * + * @return index + */ + public static TensorIndex slice(long start, Long end){ + return slice(start, end, 1); } - public static TensorIndex slice(Integer start, int end){ - return intSlice(start, end, 1); + /** + * An index that returns elements between {@code start} and {@code end}. + * If {@code start} or {@code end} is {@code null}, starts or ends at the beginning or the end, respectively. + *

+ * Analogous to Python's {@code :} slice syntax. + * + * @return index + */ + public static TensorIndex slice(Long start, long end){ + return slice(start, end, 1); } - public static TensorIndex slice(int start, Integer end){ - return intSlice(start, end, 1); + /** + * An index that returns elements between {@code start} and {@code end}. + * If {@code start} or {@code end} is {@code null}, starts or ends at the beginning or the end, respectively. + *

+ * Analogous to Python's {@code :} slice syntax. + * + * @return index + */ + public static TensorIndex slice(long start, long end){ + return slice(start, end, 1); } - public static TensorIndex slice(Integer start, int end, long stride){ - return intSlice(start, end, stride); + /** + * An index that returns every {@code stride}-th element between {@code start} and {@code end}. + * If {@code start} or {@code end} is {@code null}, starts or ends at the beginning or the end, respectively. + *

+ * Analogous to Python's {@code :} slice syntax. + * + * @return index + */ + public static TensorIndex slice(Long start, Long end, long stride){ + return new Slice(start, end, stride); + } + + /** + * An index that returns every {@code stride}-th element between {@code start} and {@code end}. + * If {@code start} or {@code end} is {@code null}, starts or ends at the beginning or the end, respectively. + *

+ * Analogous to Python's {@code :} slice syntax. + * + * @return index + */ + public static TensorIndex slice(long start, Long end, long stride){ + return new Slice(start, end, stride); } - public static TensorIndex slice(int start, Integer end, long stride){ - return intSlice(start, end, stride); + /** + * An index that returns every {@code stride}-th element between {@code start} and {@code end}. + * If {@code start} or {@code end} is {@code null}, starts or ends at the beginning or the end, respectively. + *

+ * Analogous to Python's {@code :} slice syntax. + * + * @return index + */ + public static TensorIndex slice(Long start, long end, long stride){ + return new Slice(start, end, stride); } - private static TensorIndex intSlice(Integer start, Integer end, long stride){ - return new Slice(start == null ? null : start.longValue(), end == null ? null : end.longValue(), stride); + /** + * An index that returns every {@code stride}-th element between {@code start} and {@code end}. + * If {@code start} or {@code end} is {@code null}, starts or ends at the beginning or the end, respectively. + *

+ * Analogous to Python's {@code :} slice syntax. + * + * @return index + */ + public static TensorIndex slice(long start, long end, long stride){ + return new Slice(start, end, stride); } } diff --git a/ndarray/src/main/java/org/tensorflow/ndarray/index/Odd.java b/ndarray/src/main/java/org/tensorflow/ndarray/index/Odd.java deleted file mode 100644 index 070331f1ffb..00000000000 --- a/ndarray/src/main/java/org/tensorflow/ndarray/index/Odd.java +++ /dev/null @@ -1,37 +0,0 @@ -/* - Copyright 2019 The TensorFlow Authors. All Rights Reserved. - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - ======================================================================= - */ -package org.tensorflow.ndarray.index; - -import org.tensorflow.ndarray.impl.dimension.Dimension; - -final class Odd implements Index { - - static final Odd INSTANCE = new Odd(); - - @Override - public long numElements(Dimension dim) { - return dim.numElements() >> 1; - } - - @Override - public long mapCoordinate(long coordinate, Dimension dim) { - return (coordinate << 1) + 1; - } - - private Odd() { - } -} diff --git a/ndarray/src/main/java/org/tensorflow/ndarray/index/Range.java b/ndarray/src/main/java/org/tensorflow/ndarray/index/Range.java deleted file mode 100644 index e5d6003d87b..00000000000 --- a/ndarray/src/main/java/org/tensorflow/ndarray/index/Range.java +++ /dev/null @@ -1,40 +0,0 @@ -/* - Copyright 2019 The TensorFlow Authors. All Rights Reserved. - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - ======================================================================= - */ -package org.tensorflow.ndarray.index; - -import org.tensorflow.ndarray.impl.dimension.Dimension; - -final class Range implements Index { - - @Override - public long numElements(Dimension dim) { - return end - start; - } - - @Override - public long mapCoordinate(long coordinate, Dimension dim) { - return start + coordinate; - } - - Range(long start, long end) { - this.start = start; - this.end = end; - } - - private final long start; - private final long end; -} diff --git a/ndarray/src/main/java/org/tensorflow/ndarray/index/Slice.java b/ndarray/src/main/java/org/tensorflow/ndarray/index/Slice.java index a1082ebb887..1ed7a166b11 100644 --- a/ndarray/src/main/java/org/tensorflow/ndarray/index/Slice.java +++ b/ndarray/src/main/java/org/tensorflow/ndarray/index/Slice.java @@ -65,9 +65,6 @@ private long end(Dimension dim){ @Override public long numElements(Dimension dim) { long length = end(dim) - start(dim); - if(stride < 0){ - length *= -1; - } return (length / stride) + (length % stride != 0 ? 1 : 0); } diff --git a/ndarray/src/main/java/org/tensorflow/ndarray/index/Step.java b/ndarray/src/main/java/org/tensorflow/ndarray/index/Step.java deleted file mode 100644 index 725abd8f2e7..00000000000 --- a/ndarray/src/main/java/org/tensorflow/ndarray/index/Step.java +++ /dev/null @@ -1,38 +0,0 @@ -/* - Copyright 2019 The TensorFlow Authors. All Rights Reserved. - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - ======================================================================= - */ -package org.tensorflow.ndarray.index; - -import org.tensorflow.ndarray.impl.dimension.Dimension; - -final class Step implements Index { - - @Override - public long numElements(Dimension dim) { - return (dim.numElements() / stepLength) + 1; // FIXME always include element 0? - } - - @Override - public long mapCoordinate(long coordinate, Dimension dim) { - return coordinate * stepLength; - } - - Step(long stepLength) { - this.stepLength = stepLength; - } - - private final long stepLength; -} diff --git a/ndarray/src/main/java/org/tensorflow/ndarray/index/To.java b/ndarray/src/main/java/org/tensorflow/ndarray/index/To.java deleted file mode 100644 index 167d1c6865e..00000000000 --- a/ndarray/src/main/java/org/tensorflow/ndarray/index/To.java +++ /dev/null @@ -1,38 +0,0 @@ -/* - Copyright 2019 The TensorFlow Authors. All Rights Reserved. - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - ======================================================================= - */ -package org.tensorflow.ndarray.index; - -import org.tensorflow.ndarray.impl.dimension.Dimension; - -final class To implements Index { - - @Override - public long numElements(Dimension dim) { - return end; - } - - @Override - public long mapCoordinate(long coordinate, Dimension dim) { - return coordinate; - } - - To(long end) { - this.end = end; - } - - private final long end; -} From f496aee8264120fa7f7cc5688f541bbb937caf6e Mon Sep 17 00:00:00 2001 From: Ryan Nett Date: Mon, 4 Jan 2021 21:05:39 -0800 Subject: [PATCH 16/24] Tests Signed-off-by: Ryan Nett --- .../org/tensorflow/ndarray/IndexTest.java | 160 ++++++++++++++++++ 1 file changed, 160 insertions(+) diff --git a/ndarray/src/test/java/org/tensorflow/ndarray/IndexTest.java b/ndarray/src/test/java/org/tensorflow/ndarray/IndexTest.java index a0590b25cee..70c6f21d30c 100644 --- a/ndarray/src/test/java/org/tensorflow/ndarray/IndexTest.java +++ b/ndarray/src/test/java/org/tensorflow/ndarray/IndexTest.java @@ -16,6 +16,7 @@ */ package org.tensorflow.ndarray; +import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertTrue; import org.junit.jupiter.api.Test; @@ -42,4 +43,163 @@ public void testNullConversions(){ assertTrue(Indices.slice(null, null).endMask(), "Passed null for slice end but didn't set end mask"); } + + @Test + public void testNewaxis(){ + IntNdArray matrix3d = NdArrays.ofInts(Shape.of(5, 4, 5)); + + matrix3d.scalars().forEachIndexed((coords, scalar) -> + scalar.setInt((int)coords[2]) + ); + + IntNdArray slice1 = matrix3d.slice(Indices.all(), Indices.all(), Indices.all(), Indices.newAxis()); + + assertEquals(Shape.of(5, 4, 5, 1), slice1.shape()); + assertEquals(0, slice1.getInt(0, 0, 0, 0)); + assertEquals(1, slice1.getInt(0, 0, 1, 0)); + assertEquals(4, slice1.getInt(0, 0, 4, 0)); + assertEquals(2, slice1.getInt(0, 1, 2, 0)); + + IntNdArray slice2 = matrix3d.slice(Indices.all(), Indices.all(), Indices.newAxis(), Indices.all()); + + assertEquals(Shape.of(5, 4, 1, 5), slice2.shape()); + assertEquals(0, slice2.getInt(0, 0, 0, 0)); + assertEquals(1, slice2.getInt(0, 0, 0, 1)); + assertEquals(4, slice2.getInt(0, 0, 0, 4)); + assertEquals(2, slice2.getInt(0, 1, 0, 2)); + + IntNdArray slice3 = matrix3d.slice(Indices.all(), Indices.newAxis(), Indices.all(), Indices.all()); + + assertEquals(Shape.of(5, 1, 4, 5), slice3.shape()); + assertEquals(0, slice3.getInt(0, 0, 0, 0)); + assertEquals(1, slice3.getInt(0, 0, 0, 1)); + assertEquals(4, slice3.getInt(0, 0, 0, 4)); + assertEquals(2, slice3.getInt(0, 0, 1, 2)); + + IntNdArray slice4 = matrix3d.slice(Indices.newAxis(), Indices.all(), Indices.all(), Indices.all()); + + assertEquals(Shape.of(1, 5, 4, 5), slice4.shape()); + assertEquals(0, slice4.getInt(0, 0, 0, 0)); + assertEquals(1, slice4.getInt(0, 0, 0, 1)); + assertEquals(4, slice4.getInt(0, 0, 0, 4)); + assertEquals(2, slice4.getInt(0, 0, 1, 2)); + + } + + @Test + public void testEllipsis(){ + IntNdArray matrix3d = NdArrays.ofInts(Shape.of(5, 4, 5)); + + matrix3d.scalars().forEachIndexed((coords, scalar) -> + scalar.setInt((int)coords[2]) + ); + + assertEquals( + matrix3d.slice(Indices.all(), Indices.all(), Indices.at(0)), + matrix3d.slice(Indices.ellipsis(), Indices.at(0)) + ); + + assertEquals( + matrix3d.slice(Indices.at(0), Indices.all(), Indices.all()), + matrix3d.slice(Indices.at(0), Indices.ellipsis()) + ); + + assertEquals( + matrix3d.slice(Indices.at(0), Indices.all(), Indices.at(0)), + matrix3d.slice(Indices.at(0), Indices.ellipsis(), Indices.at(0)) + ); + + // newaxis interacts specially with ellipsis (since it doesn't consume a dimension), test this + + assertEquals( + matrix3d.slice(Indices.all(), Indices.all(), Indices.newAxis(), Indices.at(0)), + matrix3d.slice(Indices.ellipsis(), Indices.newAxis(), Indices.at(0)) + ); + + assertEquals( + matrix3d.slice(Indices.newAxis(), Indices.all(), Indices.all(), Indices.at(0)), + matrix3d.slice(Indices.newAxis(), Indices.ellipsis(), Indices.at(0)) + ); + + assertEquals( + matrix3d.slice(Indices.all(), Indices.all(), Indices.at(0), Indices.newAxis()), + matrix3d.slice(Indices.ellipsis(), Indices.at(0), Indices.newAxis()) + ); + } + + @Test + public void testSlice(){ + IntNdArray matrix3d = NdArrays.ofInts(Shape.of(5, 4, 5)); + + matrix3d.scalars().forEachIndexed((coords, scalar) -> + scalar.setInt((int)coords[2]) + ); + + IntNdArray slice1 = matrix3d.slice(Indices.all(), Indices.slice(null, 3), Indices.all()); + + assertEquals(Shape.of(5, 3, 5), slice1.shape()); + assertEquals(0, slice1.getInt(0, 0, 0)); + assertEquals(1, slice1.getInt(0, 0, 1)); + assertEquals(2, slice1.getInt(0, 1, 2)); + + IntNdArray slice2 = matrix3d.slice(Indices.all(), Indices.all(), Indices.slice(1, 4)); + + assertEquals(Shape.of(5, 4, 3), slice2.shape()); + assertEquals(1, slice2.getInt(0, 0, 0)); + assertEquals(3, slice2.getInt(0, 0, 2)); + assertEquals(2, slice2.getInt(0, 1, 1)); + + assertEquals(slice2, matrix3d.slice(Indices.all(), Indices.all(), Indices.slice(1, -1))); + + assertEquals(slice2, matrix3d.slice(Indices.all(), Indices.all(), Indices.slice(-4, -1))); + + assertEquals(Shape.of(5, 4, 0), matrix3d.slice(Indices.all(), Indices.all(), Indices.slice(1, 4, -2)).shape()); + + IntNdArray slice3 = matrix3d.slice(Indices.all(), Indices.all(), Indices.slice(4, 1, -2)); + + assertEquals(Shape.of(5, 4, 2), slice3.shape()); + assertEquals(4, slice3.getInt(0, 0, 0)); + assertEquals(2, slice3.getInt(0, 1, 1)); + + assertEquals(slice3, matrix3d.slice(Indices.all(), Indices.all(), Indices.slice(-1, 1, -2))); + + assertEquals(slice3, matrix3d.slice(Indices.all(), Indices.all(), Indices.slice(-1, -4, -2))); + + IntNdArray slice4 = matrix3d.slice(Indices.all(), Indices.all(), Indices.slice(null, null, -1)); + + assertEquals(Shape.of(5, 4, 5), slice4.shape()); + assertEquals(4, slice4.getInt(0, 0, 0)); + assertEquals(3, slice4.getInt(0, 0, 1)); + assertEquals(2, slice4.getInt(0, 1, 2)); + } + + @Test + public void testAt(){ + IntNdArray matrix3d = NdArrays.ofInts(Shape.of(5, 4, 5)); + + matrix3d.scalars().forEachIndexed((coords, scalar) -> + scalar.setInt((int)coords[2]) + ); + + IntNdArray slice1 = matrix3d.slice(Indices.all(), Indices.all(), Indices.at(0)); + + assertEquals(Shape.of(5, 4), slice1.shape()); + assertEquals(0, slice1.getInt(0, 0)); + + IntNdArray slice2 = matrix3d.slice(Indices.all(), Indices.all(), Indices.at(3)); + + assertEquals(Shape.of(5, 4), slice2.shape()); + assertEquals(3, slice2.getInt(0, 0)); + + IntNdArray slice3 = matrix3d.slice(Indices.all(), Indices.all(), Indices.at(-3)); + + assertEquals(Shape.of(5, 4), slice3.shape()); + assertEquals(2, slice3.getInt(0, 0)); + + IntNdArray slice4 = matrix3d.slice(Indices.all(), Indices.all(), Indices.at(-3, true)); + + assertEquals(Shape.of(5, 4, 1), slice4.shape()); + assertEquals(2, slice4.getInt(0, 0, 0)); + } + } From 4f58b193fa124b9d67a3d76c79633b6820f224a4 Mon Sep 17 00:00:00 2001 From: Ryan Nett Date: Mon, 4 Jan 2021 21:10:38 -0800 Subject: [PATCH 17/24] ToString methods Signed-off-by: Ryan Nett --- .../main/java/org/tensorflow/ndarray/index/All.java | 5 +++++ .../main/java/org/tensorflow/ndarray/index/At.java | 9 +++++++++ .../java/org/tensorflow/ndarray/index/Ellipsis.java | 6 ++++++ .../java/org/tensorflow/ndarray/index/Hyperslab.java | 11 +++++++++++ .../java/org/tensorflow/ndarray/index/NewAxis.java | 5 +++++ .../java/org/tensorflow/ndarray/index/Sequence.java | 8 ++++++++ .../main/java/org/tensorflow/ndarray/index/Slice.java | 10 ++++++++++ 7 files changed, 54 insertions(+) diff --git a/ndarray/src/main/java/org/tensorflow/ndarray/index/All.java b/ndarray/src/main/java/org/tensorflow/ndarray/index/All.java index 1efd6bafe53..b6a57b124c1 100644 --- a/ndarray/src/main/java/org/tensorflow/ndarray/index/All.java +++ b/ndarray/src/main/java/org/tensorflow/ndarray/index/All.java @@ -49,4 +49,9 @@ public boolean beginMask() { public boolean endMask() { return true; } + + @Override + public String toString() { + return "All()"; + } } diff --git a/ndarray/src/main/java/org/tensorflow/ndarray/index/At.java b/ndarray/src/main/java/org/tensorflow/ndarray/index/At.java index 3fbdab82095..43596beb6f3 100644 --- a/ndarray/src/main/java/org/tensorflow/ndarray/index/At.java +++ b/ndarray/src/main/java/org/tensorflow/ndarray/index/At.java @@ -16,6 +16,7 @@ */ package org.tensorflow.ndarray.index; +import java.util.StringJoiner; import org.tensorflow.ndarray.impl.dimension.Dimension; final class At implements TensorIndex { @@ -68,4 +69,12 @@ public long end() { public boolean shrinkAxisMask() { return !keepDim; } + + @Override + public String toString() { + return new StringJoiner(", ", At.class.getSimpleName() + "(", ")") + .add("coord=" + coord) + .add("keepDim=" + keepDim) + .toString(); + } } diff --git a/ndarray/src/main/java/org/tensorflow/ndarray/index/Ellipsis.java b/ndarray/src/main/java/org/tensorflow/ndarray/index/Ellipsis.java index f3e9247431d..6b05c5a7a95 100644 --- a/ndarray/src/main/java/org/tensorflow/ndarray/index/Ellipsis.java +++ b/ndarray/src/main/java/org/tensorflow/ndarray/index/Ellipsis.java @@ -16,6 +16,7 @@ */ package org.tensorflow.ndarray.index; +import java.util.StringJoiner; import org.tensorflow.ndarray.impl.dimension.Dimension; final class Ellipsis implements TensorIndex{ @@ -45,4 +46,9 @@ public boolean isEllipsis() { public boolean ellipsisMask() { return true; } + + @Override + public String toString() { + return "Ellipsis()"; + } } diff --git a/ndarray/src/main/java/org/tensorflow/ndarray/index/Hyperslab.java b/ndarray/src/main/java/org/tensorflow/ndarray/index/Hyperslab.java index 00b411d0167..8131eb8b2dc 100644 --- a/ndarray/src/main/java/org/tensorflow/ndarray/index/Hyperslab.java +++ b/ndarray/src/main/java/org/tensorflow/ndarray/index/Hyperslab.java @@ -15,6 +15,7 @@ */ package org.tensorflow.ndarray.index; +import java.util.StringJoiner; import org.tensorflow.ndarray.impl.dimension.Dimension; /** @@ -71,4 +72,14 @@ public boolean isPoint() { private final long stride; private final long count; private final long block; + + @Override + public String toString() { + return new StringJoiner(", ", Hyperslab.class.getSimpleName() + "(", ")") + .add("start=" + start) + .add("stride=" + stride) + .add("count=" + count) + .add("block=" + block) + .toString(); + } } diff --git a/ndarray/src/main/java/org/tensorflow/ndarray/index/NewAxis.java b/ndarray/src/main/java/org/tensorflow/ndarray/index/NewAxis.java index f29f7058732..278fb735b45 100644 --- a/ndarray/src/main/java/org/tensorflow/ndarray/index/NewAxis.java +++ b/ndarray/src/main/java/org/tensorflow/ndarray/index/NewAxis.java @@ -50,4 +50,9 @@ public boolean isNewAxis() { public boolean newAxisMask() { return true; } + + @Override + public String toString() { + return "NewAxis()"; + } } diff --git a/ndarray/src/main/java/org/tensorflow/ndarray/index/Sequence.java b/ndarray/src/main/java/org/tensorflow/ndarray/index/Sequence.java index 41d37d05806..66de7bbbb52 100644 --- a/ndarray/src/main/java/org/tensorflow/ndarray/index/Sequence.java +++ b/ndarray/src/main/java/org/tensorflow/ndarray/index/Sequence.java @@ -16,6 +16,7 @@ */ package org.tensorflow.ndarray.index; +import java.util.StringJoiner; import org.tensorflow.ndarray.NdArray; import org.tensorflow.ndarray.impl.dimension.Dimension; @@ -36,4 +37,11 @@ public long mapCoordinate(long coordinate, Dimension dim) { } private final NdArray coords; + + @Override + public String toString() { + return new StringJoiner(", ", Sequence.class.getSimpleName() + "(", ")") + .add("coords=" + coords) + .toString(); + } } diff --git a/ndarray/src/main/java/org/tensorflow/ndarray/index/Slice.java b/ndarray/src/main/java/org/tensorflow/ndarray/index/Slice.java index 1ed7a166b11..98f8fe1b7e4 100644 --- a/ndarray/src/main/java/org/tensorflow/ndarray/index/Slice.java +++ b/ndarray/src/main/java/org/tensorflow/ndarray/index/Slice.java @@ -16,6 +16,7 @@ */ package org.tensorflow.ndarray.index; +import java.util.StringJoiner; import org.tensorflow.ndarray.impl.dimension.Dimension; final class Slice implements TensorIndex { @@ -98,4 +99,13 @@ public boolean beginMask() { public boolean endMask() { return end == null; } + + @Override + public String toString() { + return new StringJoiner(", ", Slice.class.getSimpleName() + "(", ")") + .add("start=" + start) + .add("end=" + end) + .add("stride=" + stride) + .toString(); + } } From 95909616d49f766ec66ef01f6fed8bb27f2cd508 Mon Sep 17 00:00:00 2001 From: Ryan Nett Date: Thu, 7 Jan 2021 16:34:48 -0800 Subject: [PATCH 18/24] Cleanup and formatting Signed-off-by: Ryan Nett --- .../impl/dimension/DimensionalSpace.java | 22 ++-- .../org/tensorflow/ndarray/index/All.java | 2 +- .../java/org/tensorflow/ndarray/index/At.java | 11 +- .../tensorflow/ndarray/index/Ellipsis.java | 10 +- .../tensorflow/ndarray/index/Hyperslab.java | 2 +- .../org/tensorflow/ndarray/index/NewAxis.java | 4 +- .../org/tensorflow/ndarray/index/Slice.java | 66 +++++------ .../op/core/StridedSliceHelper.java | 106 +++++++++--------- 8 files changed, 109 insertions(+), 114 deletions(-) diff --git a/ndarray/src/main/java/org/tensorflow/ndarray/impl/dimension/DimensionalSpace.java b/ndarray/src/main/java/org/tensorflow/ndarray/impl/dimension/DimensionalSpace.java index 55440b7d719..7d0f0222bbe 100644 --- a/ndarray/src/main/java/org/tensorflow/ndarray/impl/dimension/DimensionalSpace.java +++ b/ndarray/src/main/java/org/tensorflow/ndarray/impl/dimension/DimensionalSpace.java @@ -18,8 +18,8 @@ package org.tensorflow.ndarray.impl.dimension; import java.util.Arrays; -import org.tensorflow.ndarray.index.Index; import org.tensorflow.ndarray.Shape; +import org.tensorflow.ndarray.index.Index; public class DimensionalSpace { @@ -46,12 +46,12 @@ public RelativeDimensionalSpace mapTo(Index[] indices) { int newAxes = 0; boolean seenEllipsis = false; - for(Index idx : indices){ - if(idx.isNewAxis()){ + for (Index idx : indices) { + if (idx.isNewAxis()) { newAxes += 1; } - if(idx.isEllipsis()){ - if(seenEllipsis){ + if (idx.isEllipsis()) { + if (seenEllipsis) { throw new IllegalArgumentException("Only one ellipsis allowed"); } else { seenEllipsis = true; @@ -83,9 +83,9 @@ public RelativeDimensionalSpace mapTo(Index[] indices) { segmentationIdx = newDimIdx - 1; } - } else if(indices[indexIdx].isNewAxis()) { + } else if (indices[indexIdx].isNewAxis()) { long newSize; - if(dimIdx == 0){ + if (dimIdx == 0) { // includes everything. Should really include future reduction (at()) but that doesn't seem to cause issues // elsewhere newSize = dimensions[0].numElements() * dimensions[0].elementSize(); @@ -97,16 +97,16 @@ public RelativeDimensionalSpace mapTo(Index[] indices) { segmentationIdx = newDimIdx; // is this correct? ++newDimIdx; ++indexIdx; - } else if(indices[indexIdx].isEllipsis()){ + } else if (indices[indexIdx].isEllipsis()) { int remainingDimensions = dimensions.length - dimIdx; int requiredDimensions = 0; - for(int i = indexIdx + 1 ; i < indices.length ; i++){ - if(!indices[i].isNewAxis()){ + for (int i = indexIdx + 1; i < indices.length; i++) { + if (!indices[i].isNewAxis()) { requiredDimensions++; } } // while the number of dimensions left < the number of indices that consume axes - while(remainingDimensions > requiredDimensions){ + while (remainingDimensions > requiredDimensions) { Dimension dim = dimensions[dimIdx++]; if (dim.isSegmented()) { segmentationIdx = newDimIdx; diff --git a/ndarray/src/main/java/org/tensorflow/ndarray/index/All.java b/ndarray/src/main/java/org/tensorflow/ndarray/index/All.java index b6a57b124c1..43e4569169d 100644 --- a/ndarray/src/main/java/org/tensorflow/ndarray/index/All.java +++ b/ndarray/src/main/java/org/tensorflow/ndarray/index/All.java @@ -52,6 +52,6 @@ public boolean endMask() { @Override public String toString() { - return "All()"; + return All.class.getSimpleName() + "()"; } } diff --git a/ndarray/src/main/java/org/tensorflow/ndarray/index/At.java b/ndarray/src/main/java/org/tensorflow/ndarray/index/At.java index 43596beb6f3..54566c058e2 100644 --- a/ndarray/src/main/java/org/tensorflow/ndarray/index/At.java +++ b/ndarray/src/main/java/org/tensorflow/ndarray/index/At.java @@ -29,17 +29,16 @@ public long numElements(Dimension dim) { @Override public long mapCoordinate(long coordinate, Dimension dim) { long coord = this.coord >= 0 ? this.coord : dim.numElements() + this.coord; - return dim.positionOf(coord); // TODO validate coordinate is 0? + return dim.positionOf(coord); } @Override public Dimension apply(Dimension dim) { - if(keepDim){ - return dim.withIndex(this); - } - else { - throw new IllegalStateException(); // FIXME? + if (!keepDim) { + throw new UnsupportedOperationException("Should be handled in DimensionalSpace."); } + + return dim.withIndex(this); } @Override diff --git a/ndarray/src/main/java/org/tensorflow/ndarray/index/Ellipsis.java b/ndarray/src/main/java/org/tensorflow/ndarray/index/Ellipsis.java index 6b05c5a7a95..5577660a2ab 100644 --- a/ndarray/src/main/java/org/tensorflow/ndarray/index/Ellipsis.java +++ b/ndarray/src/main/java/org/tensorflow/ndarray/index/Ellipsis.java @@ -19,22 +19,22 @@ import java.util.StringJoiner; import org.tensorflow.ndarray.impl.dimension.Dimension; -final class Ellipsis implements TensorIndex{ +final class Ellipsis implements TensorIndex { static final Ellipsis INSTANCE = new Ellipsis(); - private Ellipsis(){ + private Ellipsis() { } @Override public long numElements(Dimension dim) { - throw new IllegalStateException(); + throw new UnsupportedOperationException("Should be handled in DimensionalSpace."); } @Override public long mapCoordinate(long coordinate, Dimension dim) { - throw new IllegalStateException(); + throw new UnsupportedOperationException("Should be handled in DimensionalSpace."); } @Override @@ -49,6 +49,6 @@ public boolean ellipsisMask() { @Override public String toString() { - return "Ellipsis()"; + return Ellipsis.class.getSimpleName() + "()"; } } diff --git a/ndarray/src/main/java/org/tensorflow/ndarray/index/Hyperslab.java b/ndarray/src/main/java/org/tensorflow/ndarray/index/Hyperslab.java index 8131eb8b2dc..632438eb4f2 100644 --- a/ndarray/src/main/java/org/tensorflow/ndarray/index/Hyperslab.java +++ b/ndarray/src/main/java/org/tensorflow/ndarray/index/Hyperslab.java @@ -75,7 +75,7 @@ public boolean isPoint() { @Override public String toString() { - return new StringJoiner(", ", Hyperslab.class.getSimpleName() + "(", ")") + return new StringJoiner(", ", Hyperslab.class.getSimpleName() + "Hyperslab(", ")") .add("start=" + start) .add("stride=" + stride) .add("count=" + count) diff --git a/ndarray/src/main/java/org/tensorflow/ndarray/index/NewAxis.java b/ndarray/src/main/java/org/tensorflow/ndarray/index/NewAxis.java index 278fb735b45..31ad51d16d7 100644 --- a/ndarray/src/main/java/org/tensorflow/ndarray/index/NewAxis.java +++ b/ndarray/src/main/java/org/tensorflow/ndarray/index/NewAxis.java @@ -22,7 +22,7 @@ final class NewAxis implements TensorIndex { static final NewAxis INSTANCE = new NewAxis(); - private NewAxis(){ + private NewAxis() { } @@ -53,6 +53,6 @@ public boolean newAxisMask() { @Override public String toString() { - return "NewAxis()"; + return NewAxis.class.getSimpleName() + "()"; } } diff --git a/ndarray/src/main/java/org/tensorflow/ndarray/index/Slice.java b/ndarray/src/main/java/org/tensorflow/ndarray/index/Slice.java index 98f8fe1b7e4..c8ec1773132 100644 --- a/ndarray/src/main/java/org/tensorflow/ndarray/index/Slice.java +++ b/ndarray/src/main/java/org/tensorflow/ndarray/index/Slice.java @@ -21,44 +21,12 @@ final class Slice implements TensorIndex { - private final Long start; - private final Long end; - private final long stride; - - private long start(Dimension dim){ - if(start == null){ - if(stride > 0){ - return 0; - } else { - return dim.numElements() - 1; // it's inclusive - } - } else if(start < 0){ - return dim.numElements() + start; - } else { - return start; - } - } - - private long end(Dimension dim){ - if(end == null){ - if(stride > 0){ - return dim.numElements(); - } else { - return -1; // it's exclusive - } - } else if(end < 0){ - return dim.numElements() + end; - } else { - return end; - } - } - Slice(Long start, Long end, long stride) { this.start = start; this.end = end; this.stride = stride; - if(stride == 0){ + if (stride == 0) { throw new IllegalArgumentException("Can not have a stride of 0"); } } @@ -108,4 +76,36 @@ public String toString() { .add("stride=" + stride) .toString(); } + + private long start(Dimension dim) { + if (start == null) { + if (stride > 0) { + return 0; + } + + return dim.numElements() - 1; // it's inclusive + } else if (start < 0) { + return dim.numElements() + start; + } + + return start; + } + + private long end(Dimension dim) { + if (end == null) { + if (stride > 0) { + return dim.numElements(); + } else { + return -1; // it's exclusive + } + } else if (end < 0) { + return dim.numElements() + end; + } else { + return end; + } + } + + private final Long start; + private final Long end; + private final long stride; } diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/core/StridedSliceHelper.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/core/StridedSliceHelper.java index 37111979d3d..8a408ff1980 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/core/StridedSliceHelper.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/core/StridedSliceHelper.java @@ -71,16 +71,21 @@ static StridedSliceArgs mergeIndexes(TensorIndex[] indices) { } begin[i] = (int) idx.begin(); - if(begin[i] != idx.begin()) - throw new IllegalArgumentException("Can't convert long begin value to int for index " + idx + ": Out of bounds"); + if (begin[i] != idx.begin()) { + throw new IllegalArgumentException( + "Can't convert long begin value to int for index " + idx + ": Out of bounds"); + } end[i] = (int) idx.end(); - if(end[i] != idx.end()) + if (end[i] != idx.end()) { throw new IllegalArgumentException("Can't convert long end value to int for index " + idx + ": Out of bounds"); + } strides[i] = (int) idx.stride(); - if(strides[i] != idx.stride()) - throw new IllegalArgumentException("Can't convert long stride value to int for index " + idx + ": Out of bounds"); + if (strides[i] != idx.stride()) { + throw new IllegalArgumentException( + "Can't convert long stride value to int for index " + idx + ": Out of bounds"); + } if (idx.beginMask()) { beginMask |= 1L << i; @@ -91,8 +96,9 @@ static StridedSliceArgs mergeIndexes(TensorIndex[] indices) { } if (idx.ellipsisMask()) { - if(ellipsisMask != 0) + if (ellipsisMask != 0) { throw new IllegalArgumentException("Can not have two ellipsis in a slice"); + } ellipsisMask |= 1L << i; } @@ -107,53 +113,43 @@ static StridedSliceArgs mergeIndexes(TensorIndex[] indices) { return new StridedSliceArgs(begin, end, strides, beginMask, endMask, ellipsisMask, newAxisMask, shrinkAxisMask); } + /** * Return a strided slice from `input`. - *

- * The goal of this op is to produce a new tensor with a subset of - * the elements from the `n` dimensional `input` tensor. The subset is chosen using - * a sequence of `m` sparse range specifications encoded into the arguments - * of this function. Note, in some cases - * `m` could be equal to `n`, but this need not be the case. Each - * range specification entry can be one of the following: - *

- * - An ellipsis (...) using {@link Indices#ellipsis()}. Ellipses are used to imply zero or more - * dimensions of full-dimension selection and are produced using - * `ellipsis_mask`. For example, `foo[...]` is the identity slice. - *

- * - A new axis using {@link Indices#newAxis()}. This is used to insert a new shape=1 dimension and is - * produced using `new_axis_mask`. For example, `foo[:, ...]` where - * `foo` is shape `(3, 4)` produces a `(1, 3, 4)` tensor. - *

- * - A range `begin:end:stride` using {@link Indices#slice(Long, Long, long)} Index.slice()}. This is used to specify how much to choose from - * a given dimension. `stride` can be any integer but 0. `begin` is an integer - * which represents the index of the first value to select while `end` represents - * the index of the last value to select. The number of values selected in each - * dimension is `end - begin` if `stride > 0` and `begin - end` if `stride < 0`. - * `begin` and `end` can be negative where `-1` is the last element, `-2` is - * the second to last. `begin_mask` controls whether to replace the explicitly - * given `begin` with an implicit effective value of `0` if `stride > 0` and - * `-1` if `stride < 0`. `end_mask` is analogous but produces the number - * required to create the largest open interval. For example, given a shape - * `(3,)` tensor `foo[:]`, the effective `begin` and `end` are `0` and `3`. Do - * not assume this is equivalent to `foo[0:-1]` which has an effective `begin` - * and `end` of `0` and `2`. Another example is `foo[-2::-1]` which reverses the - * first dimension of a tensor while dropping the last two (in the original - * order elements). For example `foo = [1,2,3,4]; foo[-2::-1]` is `[4,3]`. - *

- * - A single index using {@link Indices#at(long)}. This is used to keep only elements that have a given - * index. For example (`foo[2, :]` on a shape `(5,6)` tensor produces a - * shape `(6,)` tensor. This is encoded in `begin` and `end` and - * `shrink_axis_mask`. - *

+ *

+ * The goal of this op is to produce a new tensor with a subset of the elements from the `n` dimensional `input` + * tensor. The subset is chosen using a sequence of `m` sparse range specifications encoded into the arguments of this + * function. Note, in some cases `m` could be equal to `n`, but this need not be the case. Each range specification + * entry can be one of the following: + *

+ * - An ellipsis (...) using {@link Indices#ellipsis()}. Ellipses are used to imply zero or more dimensions of + * full-dimension selection and are produced using `ellipsis_mask`. For example, `foo[...]` is the identity slice. + *

+ * - A new axis using {@link Indices#newAxis()}. This is used to insert a new shape=1 dimension and is produced using + * `new_axis_mask`. For example, `foo[:, ...]` where `foo` is shape `(3, 4)` produces a `(1, 3, 4)` tensor. + *

+ * - A range `begin:end:stride` using {@link Indices#slice(Long, Long, long)} Index.slice()}. This is used to specify + * how much to choose from a given dimension. `stride` can be any integer but 0. `begin` is an integer which + * represents the index of the first value to select while `end` represents the index of the last value to select. The + * number of values selected in each dimension is `end - begin` if `stride > 0` and `begin - end` if `stride < 0`. + * `begin` and `end` can be negative where `-1` is the last element, `-2` is the second to last. `begin_mask` controls + * whether to replace the explicitly given `begin` with an implicit effective value of `0` if `stride > 0` and `-1` if + * `stride < 0`. `end_mask` is analogous but produces the number required to create the largest open interval. For + * example, given a shape `(3,)` tensor `foo[:]`, the effective `begin` and `end` are `0` and `3`. Do not assume this + * is equivalent to `foo[0:-1]` which has an effective `begin` and `end` of `0` and `2`. Another example is + * `foo[-2::-1]` which reverses the first dimension of a tensor while dropping the last two (in the original order + * elements). For example `foo = [1,2,3,4]; foo[-2::-1]` is `[4,3]`. + *

+ * - A single index using {@link Indices#at(long)}. This is used to keep only elements that have a given index. For + * example (`foo[2, :]` on a shape `(5,6)` tensor produces a shape `(6,)` tensor. This is encoded in `begin` and `end` + * and `shrink_axis_mask`. + *

* - * Requirements: - * `0 != strides[i] for i in [0, m)` - * Only one ellipsis. + * Requirements: + * `0 != strides[i] for i in [0, m)` Only one ellipsis. * * @param scope current scope * @param data type for {@code output()} output - * @param input * @param indices The indices to slice. See {@link Indices}. * @return a new instance of StridedSlice * @see Indices @@ -177,13 +173,12 @@ public static StridedSlice stridedSlice(Scope scope, Operan /** * Assign `value` to the sliced l-value reference of `ref`. - *

- * The values of `value` are assigned to the positions in the variable - * `ref` that are selected by the slice parameters. The slice parameters - * `begin`, `end`, `strides`, etc. work exactly as in `StridedSlice`. - *

- * NOTE this op currently does not support broadcasting and so `value`'s - * shape must be exactly the shape produced by the slice of `ref`. + *

+ * The values of `value` are assigned to the positions in the variable `ref` that are selected by the slice + * parameters. The slice parameters `begin`, `end`, `strides`, etc. work exactly as in `StridedSlice`. + *

+ * NOTE this op currently does not support broadcasting and so `value`'s shape must be exactly the shape produced by + * the slice of `ref`. * * @param data type for {@code outputRef()} output * @param scope current scope @@ -194,7 +189,8 @@ public static StridedSlice stridedSlice(Scope scope, Operan * @see org.tensorflow.op.Ops#stridedSlice(Operand, TensorIndex...) */ @Endpoint(name = "stridedSliceAssign") - public static StridedSliceAssign stridedSliceAssign(Scope scope, Operand ref, Operand value, TensorIndex... indices) { + public static StridedSliceAssign stridedSliceAssign(Scope scope, Operand ref, + Operand value, TensorIndex... indices) { StridedSliceArgs args = mergeIndexes(indices); return StridedSliceAssign.create( scope, From 090cfb9c3a987426c5d641fc955cf19427442d87 Mon Sep 17 00:00:00 2001 From: Ryan Nett Date: Thu, 7 Jan 2021 16:45:11 -0800 Subject: [PATCH 19/24] Cleanup and formatting Signed-off-by: Ryan Nett --- .../org/tensorflow/ndarray/index/All.java | 2 +- .../java/org/tensorflow/ndarray/index/At.java | 2 +- .../tensorflow/ndarray/index/Ellipsis.java | 3 +- .../tensorflow/ndarray/index/Hyperslab.java | 5 ++ .../org/tensorflow/ndarray/index/Index.java | 60 +++++++++++++++---- .../org/tensorflow/ndarray/index/Indices.java | 46 +++++++------- .../org/tensorflow/ndarray/index/NewAxis.java | 2 +- .../tensorflow/ndarray/index/Sequence.java | 5 ++ .../org/tensorflow/ndarray/index/Slice.java | 2 +- .../tensorflow/ndarray/index/TensorIndex.java | 50 ---------------- .../annotations/org/tensorflow/op/Ops.java | 8 +-- .../op/core/StridedSliceHelper.java | 23 ++++--- .../org/tensorflow/op/core/IndexingTest.java | 5 +- 13 files changed, 107 insertions(+), 106 deletions(-) delete mode 100644 ndarray/src/main/java/org/tensorflow/ndarray/index/TensorIndex.java diff --git a/ndarray/src/main/java/org/tensorflow/ndarray/index/All.java b/ndarray/src/main/java/org/tensorflow/ndarray/index/All.java index 43e4569169d..9d3139f3248 100644 --- a/ndarray/src/main/java/org/tensorflow/ndarray/index/All.java +++ b/ndarray/src/main/java/org/tensorflow/ndarray/index/All.java @@ -18,7 +18,7 @@ import org.tensorflow.ndarray.impl.dimension.Dimension; -final class All implements TensorIndex { +final class All implements Index { static final All INSTANCE = new All(); diff --git a/ndarray/src/main/java/org/tensorflow/ndarray/index/At.java b/ndarray/src/main/java/org/tensorflow/ndarray/index/At.java index 54566c058e2..ebd10dc6c32 100644 --- a/ndarray/src/main/java/org/tensorflow/ndarray/index/At.java +++ b/ndarray/src/main/java/org/tensorflow/ndarray/index/At.java @@ -19,7 +19,7 @@ import java.util.StringJoiner; import org.tensorflow.ndarray.impl.dimension.Dimension; -final class At implements TensorIndex { +final class At implements Index { @Override public long numElements(Dimension dim) { diff --git a/ndarray/src/main/java/org/tensorflow/ndarray/index/Ellipsis.java b/ndarray/src/main/java/org/tensorflow/ndarray/index/Ellipsis.java index 5577660a2ab..0ab9e2ec883 100644 --- a/ndarray/src/main/java/org/tensorflow/ndarray/index/Ellipsis.java +++ b/ndarray/src/main/java/org/tensorflow/ndarray/index/Ellipsis.java @@ -16,10 +16,9 @@ */ package org.tensorflow.ndarray.index; -import java.util.StringJoiner; import org.tensorflow.ndarray.impl.dimension.Dimension; -final class Ellipsis implements TensorIndex { +final class Ellipsis implements Index { static final Ellipsis INSTANCE = new Ellipsis(); diff --git a/ndarray/src/main/java/org/tensorflow/ndarray/index/Hyperslab.java b/ndarray/src/main/java/org/tensorflow/ndarray/index/Hyperslab.java index 632438eb4f2..dc8eb6244dd 100644 --- a/ndarray/src/main/java/org/tensorflow/ndarray/index/Hyperslab.java +++ b/ndarray/src/main/java/org/tensorflow/ndarray/index/Hyperslab.java @@ -82,4 +82,9 @@ public String toString() { .add("block=" + block) .toString(); } + + @Override + public boolean tensorSupport() { + return false; + } } diff --git a/ndarray/src/main/java/org/tensorflow/ndarray/index/Index.java b/ndarray/src/main/java/org/tensorflow/ndarray/index/Index.java index b459ff7a99d..028405ce243 100644 --- a/ndarray/src/main/java/org/tensorflow/ndarray/index/Index.java +++ b/ndarray/src/main/java/org/tensorflow/ndarray/index/Index.java @@ -23,19 +23,16 @@ * An index used for slicing a view out of an N-dimensional array. * *

A slice, i.e. a reduced view, of an N-dimensional array is obtain by calling - * {@link NdArray#slice(Index...)}, given a list of indices - * that select which elements on a given dimension should be included/excluded - * from that view. + * {@link NdArray#slice(Index...)}, given a list of indices that select which elements on a given dimension should be + * included/excluded from that view. */ public interface Index { /** - * Returns the number of elements that can be retrieved using this index on the - * given dimension. + * Returns the number of elements that can be retrieved using this index on the given dimension. * *

An index that maps one-by-one all elements of the dimensions will return a value - * equal to {@code dim.numElements()}, while an index that only maps a subset of these - * will return a smaller value. + * equal to {@code dim.numElements()}, while an index that only maps a subset of these will return a smaller value. * * @param dim the indexed dimension * @return number of elements accessible @@ -43,8 +40,7 @@ public interface Index { long numElements(Dimension dim); /** - * Transforms an element coordinate to a new coordinate by applying this index to the - * given dimension. + * Transforms an element coordinate to a new coordinate by applying this index to the given dimension. * *

For example, if the coordinate is 0 and this index flips the {@code n} elements on this * dimension, then the returned value will be {@code n-1}. @@ -83,10 +79,52 @@ default boolean isNewAxis() { } /** - * Returns true if this index is an ellipsis, expanding to take as many dimensions as possible - * (and applying all() to them) + * Returns true if this index is an ellipsis, expanding to take as many dimensions as possible (and applying all() to + * them) */ default boolean isEllipsis() { return false; } + + /** + * Get whether the Index supports Tensor conversion. + */ + default boolean tensorSupport() { + return true; + } + + /** + * Get the start of the index, for Tensor conversion. + */ + default long begin() { + return 0; + } + + /** + * Get the end of the index, for Tensor conversion. + */ + default long end() { + return 0; + } + + /** + * Get the stride of the index, for Tensor conversion. + */ + default long stride() { + return 1; + } + + /** + * Get whether the Index should start at the beginning of the dimension, for Tensor conversion. + */ + default boolean beginMask() { + return false; + } + + /** + * Get whether the Index should end at the beginning of the dimension, for Tensor conversion. + */ + default boolean endMask() { + return false; + } } diff --git a/ndarray/src/main/java/org/tensorflow/ndarray/index/Indices.java b/ndarray/src/main/java/org/tensorflow/ndarray/index/Indices.java index 42018dc9ca3..ac73a6aa25b 100644 --- a/ndarray/src/main/java/org/tensorflow/ndarray/index/Indices.java +++ b/ndarray/src/main/java/org/tensorflow/ndarray/index/Indices.java @@ -40,7 +40,7 @@ public final class Indices { * @param coord coordinate of the element on the indexed axis * @return index */ - public static TensorIndex at(long coord) { + public static Index at(long coord) { return new At(coord, false); } @@ -54,7 +54,7 @@ public static TensorIndex at(long coord) { * @return index * @throws IllegalRankException if {@code coord} is not a scalar (rank 0) */ - public static TensorIndex at(NdArray coord) { + public static Index at(NdArray coord) { if (coord.rank() > 0) { throw new IllegalRankException("Only scalars are accepted as a value index"); } @@ -76,7 +76,7 @@ public static TensorIndex at(NdArray coord) { * @param keepDim whether to remove the dimension. * @return index */ - public static TensorIndex at(long coord, boolean keepDim) { + public static Index at(long coord, boolean keepDim) { return new At(coord, keepDim); } @@ -93,7 +93,7 @@ public static TensorIndex at(long coord, boolean keepDim) { * @param keepDim whether to remove the dimension. * @throws IllegalRankException if {@code coord} is not a scalar (rank 0) */ - public static TensorIndex at(NdArray coord, boolean keepDim) { + public static Index at(NdArray coord, boolean keepDim) { if (coord.rank() > 0) { throw new IllegalRankException("Only scalars are accepted as a value index"); } @@ -111,7 +111,7 @@ public static TensorIndex at(NdArray coord, boolean keepDim) { * * @return index */ - public static TensorIndex all() { + public static Index all() { return All.INSTANCE; } @@ -157,7 +157,7 @@ public static Index seq(NdArray coords) { * * @return index */ - public static TensorIndex even() { + public static Index even() { return slice(null, null, 2); } @@ -170,7 +170,7 @@ public static TensorIndex even() { * * @return index */ - public static TensorIndex odd() { + public static Index odd() { return slice(1, null, 2); } @@ -183,7 +183,7 @@ public static TensorIndex odd() { * @param stepLength the number of elements between each steps * @return index */ - public static TensorIndex step(long stepLength) { + public static Index step(long stepLength) { return slice(null, null, stepLength); } @@ -197,7 +197,7 @@ public static TensorIndex step(long stepLength) { * @param start coordinate of the first element of the sequence * @return index */ - public static TensorIndex from(long start) { + public static Index from(long start) { return slice(start, null); } @@ -211,7 +211,7 @@ public static TensorIndex from(long start) { * @param end coordinate of the last element of the sequence (exclusive) * @return index */ - public static TensorIndex to(long end) { + public static Index to(long end) { return slice(null, end); } @@ -225,7 +225,7 @@ public static TensorIndex to(long end) { * @param end coordinate of the last element of the sequence (exclusive) * @return index */ - public static TensorIndex range(long start, long end) { + public static Index range(long start, long end) { return slice(start, end); } @@ -237,7 +237,7 @@ public static TensorIndex range(long start, long end) { * * @return index */ - public static TensorIndex flip() { + public static Index flip() { return slice(null, null, -1); } @@ -261,7 +261,7 @@ public static Index hyperslab(long start, long stride, long count, long block) { * * @return index */ - public static TensorIndex newAxis(){ + public static Index newAxis(){ return NewAxis.INSTANCE; } @@ -271,7 +271,7 @@ public static TensorIndex newAxis(){ * @see #expand() * @return index */ - public static TensorIndex ellipsis(){ + public static Index ellipsis(){ return Ellipsis.INSTANCE; } @@ -281,7 +281,7 @@ public static TensorIndex ellipsis(){ * * @return index */ - public static TensorIndex expand(){ + public static Index expand(){ return ellipsis(); } @@ -293,7 +293,7 @@ public static TensorIndex expand(){ * * @return index */ - public static TensorIndex slice(Long start, Long end){ + public static Index slice(Long start, Long end){ return slice(start, end, 1); } @@ -305,7 +305,7 @@ public static TensorIndex slice(Long start, Long end){ * * @return index */ - public static TensorIndex slice(long start, Long end){ + public static Index slice(long start, Long end){ return slice(start, end, 1); } @@ -317,7 +317,7 @@ public static TensorIndex slice(long start, Long end){ * * @return index */ - public static TensorIndex slice(Long start, long end){ + public static Index slice(Long start, long end){ return slice(start, end, 1); } @@ -329,7 +329,7 @@ public static TensorIndex slice(Long start, long end){ * * @return index */ - public static TensorIndex slice(long start, long end){ + public static Index slice(long start, long end){ return slice(start, end, 1); } @@ -341,7 +341,7 @@ public static TensorIndex slice(long start, long end){ * * @return index */ - public static TensorIndex slice(Long start, Long end, long stride){ + public static Index slice(Long start, Long end, long stride){ return new Slice(start, end, stride); } @@ -353,7 +353,7 @@ public static TensorIndex slice(Long start, Long end, long stride){ * * @return index */ - public static TensorIndex slice(long start, Long end, long stride){ + public static Index slice(long start, Long end, long stride){ return new Slice(start, end, stride); } @@ -365,7 +365,7 @@ public static TensorIndex slice(long start, Long end, long stride){ * * @return index */ - public static TensorIndex slice(Long start, long end, long stride){ + public static Index slice(Long start, long end, long stride){ return new Slice(start, end, stride); } @@ -377,7 +377,7 @@ public static TensorIndex slice(Long start, long end, long stride){ * * @return index */ - public static TensorIndex slice(long start, long end, long stride){ + public static Index slice(long start, long end, long stride){ return new Slice(start, end, stride); } } diff --git a/ndarray/src/main/java/org/tensorflow/ndarray/index/NewAxis.java b/ndarray/src/main/java/org/tensorflow/ndarray/index/NewAxis.java index 31ad51d16d7..7c9be8fac56 100644 --- a/ndarray/src/main/java/org/tensorflow/ndarray/index/NewAxis.java +++ b/ndarray/src/main/java/org/tensorflow/ndarray/index/NewAxis.java @@ -18,7 +18,7 @@ import org.tensorflow.ndarray.impl.dimension.Dimension; -final class NewAxis implements TensorIndex { +final class NewAxis implements Index { static final NewAxis INSTANCE = new NewAxis(); diff --git a/ndarray/src/main/java/org/tensorflow/ndarray/index/Sequence.java b/ndarray/src/main/java/org/tensorflow/ndarray/index/Sequence.java index 66de7bbbb52..6b8951ba2d1 100644 --- a/ndarray/src/main/java/org/tensorflow/ndarray/index/Sequence.java +++ b/ndarray/src/main/java/org/tensorflow/ndarray/index/Sequence.java @@ -44,4 +44,9 @@ public String toString() { .add("coords=" + coords) .toString(); } + + @Override + public boolean tensorSupport() { + return false; + } } diff --git a/ndarray/src/main/java/org/tensorflow/ndarray/index/Slice.java b/ndarray/src/main/java/org/tensorflow/ndarray/index/Slice.java index c8ec1773132..12ce613842a 100644 --- a/ndarray/src/main/java/org/tensorflow/ndarray/index/Slice.java +++ b/ndarray/src/main/java/org/tensorflow/ndarray/index/Slice.java @@ -19,7 +19,7 @@ import java.util.StringJoiner; import org.tensorflow.ndarray.impl.dimension.Dimension; -final class Slice implements TensorIndex { +final class Slice implements Index { Slice(Long start, Long end, long stride) { this.start = start; diff --git a/ndarray/src/main/java/org/tensorflow/ndarray/index/TensorIndex.java b/ndarray/src/main/java/org/tensorflow/ndarray/index/TensorIndex.java deleted file mode 100644 index c66dd18de10..00000000000 --- a/ndarray/src/main/java/org/tensorflow/ndarray/index/TensorIndex.java +++ /dev/null @@ -1,50 +0,0 @@ -/* - Copyright 2020 The TensorFlow Authors. All Rights Reserved. - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - ============================================================================== - */ -package org.tensorflow.ndarray.index; - -public interface TensorIndex extends Index{ - default long begin(){ - return 0; - } - default long end(){ - return 0; - } - - default long stride(){ - return 1; - } - - default boolean beginMask(){ - return false; - } - - default boolean endMask(){ - return false; - } - - default boolean ellipsisMask(){ - return false; - } - - default boolean newAxisMask(){ - return false; - } - - default boolean shrinkAxisMask(){ - return false; - } -} diff --git a/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/Ops.java b/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/Ops.java index 0a95dbce72a..efe1ff30373 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/Ops.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/Ops.java @@ -38,7 +38,7 @@ import org.tensorflow.ndarray.buffer.FloatDataBuffer; import org.tensorflow.ndarray.buffer.IntDataBuffer; import org.tensorflow.ndarray.buffer.LongDataBuffer; -import org.tensorflow.ndarray.index.TensorIndex; +import org.tensorflow.ndarray.index.Index; import org.tensorflow.op.core.Abort; import org.tensorflow.op.core.All; import org.tensorflow.op.core.Any; @@ -5953,7 +5953,7 @@ public StopGradient stopGradient(Operand input) { * @return a new instance of StridedSlice * @see Indices */ - public StridedSlice stridedSlice(Operand input, TensorIndex... indices) { + public StridedSlice stridedSlice(Operand input, Index... indices) { return StridedSliceHelper.stridedSlice(scope, input, indices); } @@ -6085,10 +6085,10 @@ public StridedSlice stridedSlice(Operand * @param value the value to assign. * @param indices The indices to slice. See {@link Indices}. * @return a new instance of StridedSliceAssign - * @see org.tensorflow.op.Ops#stridedSlice(Operand, TensorIndex...) + * @see org.tensorflow.op.Ops#stridedSlice(Operand, Index...) */ public StridedSliceAssign stridedSliceAssign(Operand ref, - Operand value, TensorIndex... indices) { + Operand value, Index... indices) { return StridedSliceHelper.stridedSliceAssign(scope, ref, value, indices); } diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/core/StridedSliceHelper.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/core/StridedSliceHelper.java index 8a408ff1980..335b2c7825e 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/core/StridedSliceHelper.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/core/StridedSliceHelper.java @@ -16,7 +16,7 @@ import org.tensorflow.Operand; import org.tensorflow.ndarray.index.Indices; -import org.tensorflow.ndarray.index.TensorIndex; +import org.tensorflow.ndarray.index.Index; import org.tensorflow.op.Scope; import org.tensorflow.op.annotation.Endpoint; import org.tensorflow.op.annotation.Operator; @@ -54,7 +54,7 @@ private StridedSliceArgs(int[] begin, int[] end, int[] strides, long beginMask, } } - static StridedSliceArgs mergeIndexes(TensorIndex[] indices) { + static StridedSliceArgs mergeIndexes(Index[] indices) { int[] begin = new int[indices.length]; int[] end = new int[indices.length]; int[] strides = new int[indices.length]; @@ -65,11 +65,16 @@ static StridedSliceArgs mergeIndexes(TensorIndex[] indices) { long shrinkAxisMask = 0; for (int i = 0; i < indices.length; i++) { - TensorIndex idx = indices[i]; + Index idx = indices[i]; + if (idx == null) { idx = Indices.all(); } + if (!idx.tensorSupport()) { + throw new UnsupportedOperationException("Index " + idx + " is not supported for Tensors"); + } + begin[i] = (int) idx.begin(); if (begin[i] != idx.begin()) { throw new IllegalArgumentException( @@ -95,18 +100,18 @@ static StridedSliceArgs mergeIndexes(TensorIndex[] indices) { endMask |= 1L << i; } - if (idx.ellipsisMask()) { + if (idx.isEllipsis()) { if (ellipsisMask != 0) { throw new IllegalArgumentException("Can not have two ellipsis in a slice"); } ellipsisMask |= 1L << i; } - if (idx.newAxisMask()) { + if (idx.isNewAxis()) { newAxisMask |= 1L << i; } - if (idx.shrinkAxisMask()) { + if (idx.isPoint()) { shrinkAxisMask |= 1L << i; } } @@ -155,7 +160,7 @@ static StridedSliceArgs mergeIndexes(TensorIndex[] indices) { * @see Indices */ @Endpoint(name = "stridedSlice") - public static StridedSlice stridedSlice(Scope scope, Operand input, TensorIndex... indices) { + public static StridedSlice stridedSlice(Scope scope, Operand input, Index... indices) { StridedSliceArgs args = mergeIndexes(indices); return StridedSlice.create( scope, @@ -186,11 +191,11 @@ public static StridedSlice stridedSlice(Scope scope, Operan * @param value the value to assign. * @param indices The indices to slice. See {@link Indices}. * @return a new instance of StridedSliceAssign - * @see org.tensorflow.op.Ops#stridedSlice(Operand, TensorIndex...) + * @see org.tensorflow.op.Ops#stridedSlice(Operand, Index...) */ @Endpoint(name = "stridedSliceAssign") public static StridedSliceAssign stridedSliceAssign(Scope scope, Operand ref, - Operand value, TensorIndex... indices) { + Operand value, Index... indices) { StridedSliceArgs args = mergeIndexes(indices); return StridedSliceAssign.create( scope, diff --git a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/op/core/IndexingTest.java b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/op/core/IndexingTest.java index 259a8d66d24..a4fcba2caec 100644 --- a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/op/core/IndexingTest.java +++ b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/op/core/IndexingTest.java @@ -20,17 +20,16 @@ import org.junit.Test; import org.tensorflow.Graph; import org.tensorflow.Session; -import org.tensorflow.Tensor; import org.tensorflow.ndarray.Shape; import org.tensorflow.ndarray.index.Indices; -import org.tensorflow.ndarray.index.TensorIndex; +import org.tensorflow.ndarray.index.Index; import org.tensorflow.op.Scope; import org.tensorflow.types.TFloat32; public class IndexingTest { // [2, 1:2, :, tf.newaxis, ..., :4, 4::2] - private static final TensorIndex[] slice = new TensorIndex[]{ + private static final Index[] slice = new Index[]{ Indices.at(2), Indices.at(1, true), Indices.all(), From 14e3c549fa6037f1ff934050feb6099b6f75772d Mon Sep 17 00:00:00 2001 From: Ryan Nett Date: Thu, 7 Jan 2021 17:28:46 -0800 Subject: [PATCH 20/24] Javadocs cleanup, new names Signed-off-by: Ryan Nett --- .../java/org/tensorflow/ndarray/index/At.java | 5 - .../tensorflow/ndarray/index/Ellipsis.java | 5 - .../org/tensorflow/ndarray/index/Indices.java | 36 +++++++- .../org/tensorflow/ndarray/index/NewAxis.java | 5 - .../tensorflow/ndarray/NdArrayTestBase.java | 8 +- .../ndarray/impl/dense/DenseNdArrayTest.java | 2 +- .../annotations/org/tensorflow/op/Ops.java | 92 +++++++++---------- .../op/core/StridedSliceHelper.java | 38 ++++---- 8 files changed, 104 insertions(+), 87 deletions(-) diff --git a/ndarray/src/main/java/org/tensorflow/ndarray/index/At.java b/ndarray/src/main/java/org/tensorflow/ndarray/index/At.java index ebd10dc6c32..31ce021ddc8 100644 --- a/ndarray/src/main/java/org/tensorflow/ndarray/index/At.java +++ b/ndarray/src/main/java/org/tensorflow/ndarray/index/At.java @@ -64,11 +64,6 @@ public long end() { return coord + 1; } - @Override - public boolean shrinkAxisMask() { - return !keepDim; - } - @Override public String toString() { return new StringJoiner(", ", At.class.getSimpleName() + "(", ")") diff --git a/ndarray/src/main/java/org/tensorflow/ndarray/index/Ellipsis.java b/ndarray/src/main/java/org/tensorflow/ndarray/index/Ellipsis.java index 0ab9e2ec883..d4085735df2 100644 --- a/ndarray/src/main/java/org/tensorflow/ndarray/index/Ellipsis.java +++ b/ndarray/src/main/java/org/tensorflow/ndarray/index/Ellipsis.java @@ -41,11 +41,6 @@ public boolean isEllipsis() { return true; } - @Override - public boolean ellipsisMask() { - return true; - } - @Override public String toString() { return Ellipsis.class.getSimpleName() + "()"; diff --git a/ndarray/src/main/java/org/tensorflow/ndarray/index/Indices.java b/ndarray/src/main/java/org/tensorflow/ndarray/index/Indices.java index ac73a6aa25b..425e8797ce6 100644 --- a/ndarray/src/main/java/org/tensorflow/ndarray/index/Indices.java +++ b/ndarray/src/main/java/org/tensorflow/ndarray/index/Indices.java @@ -197,7 +197,7 @@ public static Index step(long stepLength) { * @param start coordinate of the first element of the sequence * @return index */ - public static Index from(long start) { + public static Index sliceFrom(long start) { return slice(start, null); } @@ -211,10 +211,42 @@ public static Index from(long start) { * @param end coordinate of the last element of the sequence (exclusive) * @return index */ - public static Index to(long end) { + public static Index sliceTo(long end) { return slice(null, end); } + /** + * An index that returns only elements on a given dimension starting at a + * specific coordinate, using the given stride. + * + *

For example, given a vector with {@code n} elements on the {@code x} axis, and {@code n > k}, + * {@code from(k)} returns xk, xk+1, ..., xn-1 + * + * @param start coordinate of the first element of the sequence + * @param stride the stride to use + * @return index + * @see #slice(long, long, long) + */ + public static Index sliceFrom(long start, long stride) { + return slice(start, null, stride); + } + + /** + * An index that returns only elements on a given dimension up to a + * specific coordinate, using the given stride. + * + *

For example, given a vector with {@code n} elements on the {@code x} axis, and {@code n > k}, + * {@code to(k)} returns x0, x1, ..., xk + * + * @param end coordinate of the last element of the sequence (exclusive) + * @param stride the stride to use + * @return index + * @see #slice(long, long, long) + */ + public static Index sliceTo(long end, long stride) { + return slice(null, end, stride); + } + /** * An index that returns only elements on a given dimension between two coordinates. * diff --git a/ndarray/src/main/java/org/tensorflow/ndarray/index/NewAxis.java b/ndarray/src/main/java/org/tensorflow/ndarray/index/NewAxis.java index 7c9be8fac56..a68b1ed9ad1 100644 --- a/ndarray/src/main/java/org/tensorflow/ndarray/index/NewAxis.java +++ b/ndarray/src/main/java/org/tensorflow/ndarray/index/NewAxis.java @@ -46,11 +46,6 @@ public boolean isNewAxis() { return true; } - @Override - public boolean newAxisMask() { - return true; - } - @Override public String toString() { return NewAxis.class.getSimpleName() + "()"; diff --git a/ndarray/src/test/java/org/tensorflow/ndarray/NdArrayTestBase.java b/ndarray/src/test/java/org/tensorflow/ndarray/NdArrayTestBase.java index 1c1d89680e7..26ac533daa8 100644 --- a/ndarray/src/test/java/org/tensorflow/ndarray/NdArrayTestBase.java +++ b/ndarray/src/test/java/org/tensorflow/ndarray/NdArrayTestBase.java @@ -24,11 +24,11 @@ import static org.tensorflow.ndarray.index.Indices.at; import static org.tensorflow.ndarray.index.Indices.even; import static org.tensorflow.ndarray.index.Indices.flip; -import static org.tensorflow.ndarray.index.Indices.from; +import static org.tensorflow.ndarray.index.Indices.sliceFrom; import static org.tensorflow.ndarray.index.Indices.odd; import static org.tensorflow.ndarray.index.Indices.range; import static org.tensorflow.ndarray.index.Indices.seq; -import static org.tensorflow.ndarray.index.Indices.to; +import static org.tensorflow.ndarray.index.Indices.sliceTo; import java.nio.BufferOverflowException; import java.nio.BufferUnderflowException; @@ -212,13 +212,13 @@ public void slices() { assertEquals(val101, vector10_flip.getObject(3)); // Vector (1,0,[from 1]) from vector (1,0,*) - NdArray vector10_1toX = vector10X.slice(from(1)); + NdArray vector10_1toX = vector10X.slice(sliceFrom(1)); assertEquals(vector10_1toX.shape(), Shape.of(4)); assertEquals(val101, vector10_1toX.getObject(0)); assertEquals(val102, vector10_1toX.getObject(1)); // Vector (1,0,[to 1]) from vector (1,0,*) - NdArray vector10_Xto1 = vector10X.slice(to(2)); + NdArray vector10_Xto1 = vector10X.slice(sliceTo(2)); assertEquals(vector10_Xto1.shape(), Shape.of(2)); assertEquals(val100, vector10_Xto1.getObject(0)); assertEquals(val101, vector10_Xto1.getObject(1)); diff --git a/ndarray/src/test/java/org/tensorflow/ndarray/impl/dense/DenseNdArrayTest.java b/ndarray/src/test/java/org/tensorflow/ndarray/impl/dense/DenseNdArrayTest.java index d5b5ca809a4..375f7643875 100644 --- a/ndarray/src/test/java/org/tensorflow/ndarray/impl/dense/DenseNdArrayTest.java +++ b/ndarray/src/test/java/org/tensorflow/ndarray/impl/dense/DenseNdArrayTest.java @@ -40,7 +40,7 @@ public void equalsAndHashCodeOnSlices() { {{3, 4}, {6, 7}} }); - assertTrue(vector1.equals(vector2.slice(Indices.from(2)))); + assertTrue(vector1.equals(vector2.slice(Indices.sliceFrom(2)))); assertTrue(vector1.equals(matrix1.get(1))); assertTrue(vector1.equals(matrix2.get(1).slice(Indices.even()))); assertTrue(matrix1.equals(matrix2.slice(Indices.all(), Indices.even()))); diff --git a/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/Ops.java b/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/Ops.java index efe1ff30373..3cf293f759d 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/Ops.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/Ops.java @@ -5904,51 +5904,46 @@ public StopGradient stopGradient(Operand input) { /** * Return a strided slice from `input`. - *

- * The goal of this op is to produce a new tensor with a subset of - * the elements from the `n` dimensional `input` tensor. The subset is chosen using - * a sequence of `m` sparse range specifications encoded into the arguments - * of this function. Note, in some cases - * `m` could be equal to `n`, but this need not be the case. Each - * range specification entry can be one of the following: - *

- * - An ellipsis (...) using {@link Indices#ellipsis()}. Ellipses are used to imply zero or more - * dimensions of full-dimension selection and are produced using - * `ellipsis_mask`. For example, `foo[...]` is the identity slice. - *

- * - A new axis using {@link Indices#newAxis()}. This is used to insert a new shape=1 dimension and is - * produced using `new_axis_mask`. For example, `foo[:, ...]` where - * `foo` is shape `(3, 4)` produces a `(1, 3, 4)` tensor. - *

- * - A range `begin:end:stride` using {@link Indices#slice(Long, Long, long)} Index.slice()}. This is used to specify how much to choose from - * a given dimension. `stride` can be any integer but 0. `begin` is an integer - * which represents the index of the first value to select while `end` represents - * the index of the last value to select. The number of values selected in each - * dimension is `end - begin` if `stride > 0` and `begin - end` if `stride < 0`. - * `begin` and `end` can be negative where `-1` is the last element, `-2` is - * the second to last. `begin_mask` controls whether to replace the explicitly - * given `begin` with an implicit effective value of `0` if `stride > 0` and - * `-1` if `stride < 0`. `end_mask` is analogous but produces the number - * required to create the largest open interval. For example, given a shape - * `(3,)` tensor `foo[:]`, the effective `begin` and `end` are `0` and `3`. Do - * not assume this is equivalent to `foo[0:-1]` which has an effective `begin` - * and `end` of `0` and `2`. Another example is `foo[-2::-1]` which reverses the - * first dimension of a tensor while dropping the last two (in the original - * order elements). For example `foo = [1,2,3,4]; foo[-2::-1]` is `[4,3]`. - *

- * - A single index using {@link Indices#at(long)}. This is used to keep only elements that have a given - * index. For example (`foo[2, :]` on a shape `(5,6)` tensor produces a - * shape `(6,)` tensor. This is encoded in `begin` and `end` and - * `shrink_axis_mask`. - *

- * - * Requirements: - * `0 != strides[i] for i in [0, m)` - * Only one ellipsis. + *

+ * The goal of this op is to produce a new tensor with a subset of the elements from the `n` dimensional `input` + * tensor. The subset is chosen using a sequence of `m` sparse range specifications encoded into the arguments of this + * function. Note, in some cases `m` could be equal to `n`, but this need not be the case. Each range specification + * entry can be one of the following: + *

+ * - An ellipsis (...) using {@link Indices#ellipsis()}. Ellipses are used to imply zero or more dimensions of + * full-dimension selection. For example, {@code stridedSlice(foo, Indices.ellipsis()} is the identity slice. + *

+ * - A new axis using {@link Indices#newAxis()}. This is used to insert a new shape=1 dimension. + * For example, `{@code stridedSlice(foo, Indices.newAxis())} where {@code foo} is shape {@code (3, 4)} + * produces a {@code (1, 3, 4)} tensor. + *

+ * - A range {@code begin:end:stride} using {@link Indices#slice(Long, Long, long)} Index.slice()} or {@link Indices#all()}. This is used to specify + * how much to choose from a given dimension. {@code stride} can be any integer but 0. {@code begin} is an integer which + * represents the index of the first value to select while {@code end} represents the index of the last value to select + * (exclusive). Begin and end can be null, in which case the index begins or ends at the beginning or end of the dimension, + * respectively (reversed if stride is negative). When both are null, {@code slice()} is the same as {@code all()}. + * The number of values selected in each dimension is {@code end - begin} if {@code stride > 0} and {@code begin - end} + * if {@code stride < 0}. {@code begin} and {@code end} can be negative where {@code -1} is the last element, {@code -2} + * is the second to last. For example, given a shape {@code (3,)} tensor {@code stridedSlice(foo, Indices.all())}, the + * effective {@code begin} and {@code end} are {@code 0} and {@code 3}. Do not assume this is equivalent to + * {@code stridedSlice(foo, Indices.slice(0, -1))} which has an effective {@code begin} and {@code end} of {@code 0} and + * {@code 2}. Another example is {@code stridedSlice(foo, Indices.slice(-2, null, -1))} which reverses the first dimension + * of a tensor while dropping the last two (in the original order elements). For example {@code foo = [1,2,3,4]; + * stridedSlice(foo, Indices.slice(-2, null, -1)} is {@code [4,3]}. + *

+ * - A single index using {@link Indices#at(long)}. This is used to keep only elements that have a given index. For + * example ({@code stridedSlice(foo, Indices.at(2))} on a shape {@code (5,6)} tensor produces a shape {@code (6,)} tensor. + * The dimension can be kept with size one using {@link Indices#at(long, boolean)}. + *

+ * These semantics generally follow NumPy's indexing semantics, which can be found here: + * https://numpy.org/doc/stable/reference/arrays.indexing.html + *

+ * + * Requirements: + * `0 != strides[i] for i in [0, m)` Only one ellipsis. * * @param scope current scope * @param data type for {@code output()} output - * @param input * @param indices The indices to slice. See {@link Indices}. * @return a new instance of StridedSlice * @see Indices @@ -6071,13 +6066,12 @@ public StridedSlice stridedSlice(Operand /** * Assign `value` to the sliced l-value reference of `ref`. - *

- * The values of `value` are assigned to the positions in the variable - * `ref` that are selected by the slice parameters. The slice parameters - * `begin`, `end`, `strides`, etc. work exactly as in `StridedSlice`. - *

- * NOTE this op currently does not support broadcasting and so `value`'s - * shape must be exactly the shape produced by the slice of `ref`. + *

+ * The values of `value` are assigned to the positions in the variable `ref` that are selected by the slice + * parameters. The slice parameters `begin`, `end`, `strides`, etc. work exactly as in `StridedSlice`. + *

+ * NOTE this op currently does not support broadcasting and so `value`'s shape must be exactly the shape produced by + * the slice of `ref`. * * @param data type for {@code outputRef()} output * @param scope current scope diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/core/StridedSliceHelper.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/core/StridedSliceHelper.java index 335b2c7825e..cbb818597b8 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/core/StridedSliceHelper.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/core/StridedSliceHelper.java @@ -128,26 +128,32 @@ static StridedSliceArgs mergeIndexes(Index[] indices) { * entry can be one of the following: *

* - An ellipsis (...) using {@link Indices#ellipsis()}. Ellipses are used to imply zero or more dimensions of - * full-dimension selection and are produced using `ellipsis_mask`. For example, `foo[...]` is the identity slice. + * full-dimension selection. For example, {@code stridedSlice(foo, Indices.ellipsis()} is the identity slice. *

- * - A new axis using {@link Indices#newAxis()}. This is used to insert a new shape=1 dimension and is produced using - * `new_axis_mask`. For example, `foo[:, ...]` where `foo` is shape `(3, 4)` produces a `(1, 3, 4)` tensor. + * - A new axis using {@link Indices#newAxis()}. This is used to insert a new shape=1 dimension. + * For example, `{@code stridedSlice(foo, Indices.newAxis())} where {@code foo} is shape {@code (3, 4)} + * produces a {@code (1, 3, 4)} tensor. *

- * - A range `begin:end:stride` using {@link Indices#slice(Long, Long, long)} Index.slice()}. This is used to specify - * how much to choose from a given dimension. `stride` can be any integer but 0. `begin` is an integer which - * represents the index of the first value to select while `end` represents the index of the last value to select. The - * number of values selected in each dimension is `end - begin` if `stride > 0` and `begin - end` if `stride < 0`. - * `begin` and `end` can be negative where `-1` is the last element, `-2` is the second to last. `begin_mask` controls - * whether to replace the explicitly given `begin` with an implicit effective value of `0` if `stride > 0` and `-1` if - * `stride < 0`. `end_mask` is analogous but produces the number required to create the largest open interval. For - * example, given a shape `(3,)` tensor `foo[:]`, the effective `begin` and `end` are `0` and `3`. Do not assume this - * is equivalent to `foo[0:-1]` which has an effective `begin` and `end` of `0` and `2`. Another example is - * `foo[-2::-1]` which reverses the first dimension of a tensor while dropping the last two (in the original order - * elements). For example `foo = [1,2,3,4]; foo[-2::-1]` is `[4,3]`. + * - A range {@code begin:end:stride} using {@link Indices#slice(Long, Long, long)} Index.slice()} or {@link Indices#all()}. This is used to specify + * how much to choose from a given dimension. {@code stride} can be any integer but 0. {@code begin} is an integer which + * represents the index of the first value to select while {@code end} represents the index of the last value to select + * (exclusive). Begin and end can be null, in which case the index begins or ends at the beginning or end of the dimension, + * respectively (reversed if stride is negative). When both are null, {@code slice()} is the same as {@code all()}. + * The number of values selected in each dimension is {@code end - begin} if {@code stride > 0} and {@code begin - end} + * if {@code stride < 0}. {@code begin} and {@code end} can be negative where {@code -1} is the last element, {@code -2} + * is the second to last. For example, given a shape {@code (3,)} tensor {@code stridedSlice(foo, Indices.all())}, the + * effective {@code begin} and {@code end} are {@code 0} and {@code 3}. Do not assume this is equivalent to + * {@code stridedSlice(foo, Indices.slice(0, -1))} which has an effective {@code begin} and {@code end} of {@code 0} and + * {@code 2}. Another example is {@code stridedSlice(foo, Indices.slice(-2, null, -1))} which reverses the first dimension + * of a tensor while dropping the last two (in the original order elements). For example {@code foo = [1,2,3,4]; + * stridedSlice(foo, Indices.slice(-2, null, -1)} is {@code [4,3]}. *

* - A single index using {@link Indices#at(long)}. This is used to keep only elements that have a given index. For - * example (`foo[2, :]` on a shape `(5,6)` tensor produces a shape `(6,)` tensor. This is encoded in `begin` and `end` - * and `shrink_axis_mask`. + * example ({@code stridedSlice(foo, Indices.at(2))} on a shape {@code (5,6)} tensor produces a shape {@code (6,)} tensor. + * The dimension can be kept with size one using {@link Indices#at(long, boolean)}. + *

+ * These semantics generally follow NumPy's indexing semantics, which can be found here: + * https://numpy.org/doc/stable/reference/arrays.indexing.html *

* * Requirements: From 8636f90954e23139c0c69d31c296e34f94eaf9c6 Mon Sep 17 00:00:00 2001 From: Ryan Nett Date: Thu, 7 Jan 2021 17:43:22 -0800 Subject: [PATCH 21/24] Split Slice into nullability cases Signed-off-by: Ryan Nett --- .../org/tensorflow/ndarray/index/Indices.java | 176 +++++++----------- .../org/tensorflow/ndarray/index/Slice.java | 36 +--- .../tensorflow/ndarray/index/SliceFrom.java | 86 +++++++++ .../org/tensorflow/ndarray/index/SliceTo.java | 86 +++++++++ .../org/tensorflow/ndarray/index/Step.java | 83 +++++++++ .../org/tensorflow/ndarray/IndexTest.java | 10 +- .../org/tensorflow/op/core/IndexingTest.java | 4 +- 7 files changed, 336 insertions(+), 145 deletions(-) create mode 100644 ndarray/src/main/java/org/tensorflow/ndarray/index/SliceFrom.java create mode 100644 ndarray/src/main/java/org/tensorflow/ndarray/index/SliceTo.java create mode 100644 ndarray/src/main/java/org/tensorflow/ndarray/index/Step.java diff --git a/ndarray/src/main/java/org/tensorflow/ndarray/index/Indices.java b/ndarray/src/main/java/org/tensorflow/ndarray/index/Indices.java index 425e8797ce6..555d7c1a496 100644 --- a/ndarray/src/main/java/org/tensorflow/ndarray/index/Indices.java +++ b/ndarray/src/main/java/org/tensorflow/ndarray/index/Indices.java @@ -34,8 +34,8 @@ public final class Indices { * single element and therefore is excluded from the computation of the rank. * *

For example, given a 3D matrix on the axis [x, y, z], if - * {@code matrix.slice(all(), at(0), at(0)}, then the rank of the returned slice is 1 and its - * number of elements is {@code x.numElements()} + * {@code matrix.slice(all(), at(0), at(0)}, then the rank of the returned slice is 1 and its number of elements is + * {@code x.numElements()} * * @param coord coordinate of the element on the indexed axis * @return index @@ -65,12 +65,12 @@ public static Index at(NdArray coord) { * A coordinate that selects a specific element on a given dimension. * *

When this index is applied to a given dimension, the dimension is resolved as a - * single element and therefore, if {@code keepDim} is false, is excluded from the computation of the rank. - * If {@code} keepDim is true, the dimension is collapsed down to one element. + * single element and therefore, if {@code keepDim} is false, is excluded from the computation of the rank. If {@code} + * keepDim is true, the dimension is collapsed down to one element. * *

For example, given a 3D matrix on the axis [x, y, z], if - * {@code matrix.slice(all(), at(0), at(0)}, then the rank of the returned slice is 1 and its - * number of elements is {@code x.numElements()} + * {@code matrix.slice(all(), at(0), at(0)}, then the rank of the returned slice is 1 and its number of elements is + * {@code x.numElements()} * * @param coord coordinate of the element on the indexed axis * @param keepDim whether to remove the dimension. @@ -89,8 +89,8 @@ public static Index at(long coord, boolean keepDim) { * If {@code} keepDim is true, the dimension is collapsed down to one element instead of being removed. * * @param coord scalar indicating the coordinate of the element on the indexed axis - * @return index * @param keepDim whether to remove the dimension. + * @return index * @throws IllegalRankException if {@code coord} is not a scalar (rank 0) */ public static Index at(NdArray coord, boolean keepDim) { @@ -149,8 +149,7 @@ public static Index seq(NdArray coords) { } /** - * An index that returns only elements found at an even position in the - * original dimension. + * An index that returns only elements found at an even position in the original dimension. * *

For example, given a vector with {@code n} elements on the {@code x} axis, and n is even, * {@code even()} returns x0, x2, ..., xn-2 @@ -158,12 +157,11 @@ public static Index seq(NdArray coords) { * @return index */ public static Index even() { - return slice(null, null, 2); + return step(2); } /** - * An index that returns only elements found at an odd position in the - * original dimension. + * An index that returns only elements found at an odd position in the original dimension. * *

For example, given a vector with {@code n} elements on the {@code x} axis, and n is even, * {@code odd()} returns x1, x3, ..., xn-1 @@ -171,7 +169,7 @@ public static Index even() { * @return index */ public static Index odd() { - return slice(1, null, 2); + return sliceFrom(1, 2); } /** @@ -180,16 +178,15 @@ public static Index odd() { *

For example, given a vector with {@code n} elements on the {@code x} axis, * {@code step(k)} returns x0, xk, xk*2, ... * - * @param stepLength the number of elements between each steps + * @param stride the number of elements between each steps * @return index */ - public static Index step(long stepLength) { - return slice(null, null, stepLength); + public static Index step(long stride) { + return new Step(stride); } /** - * An index that returns only elements on a given dimension starting at a - * specific coordinate. + * An index that returns only elements on a given dimension starting at a specific coordinate. * *

For example, given a vector with {@code n} elements on the {@code x} axis, and {@code n > k}, * {@code from(k)} returns xk, xk+1, ..., xn-1 @@ -198,42 +195,40 @@ public static Index step(long stepLength) { * @return index */ public static Index sliceFrom(long start) { - return slice(start, null); + return sliceFrom(start, 1); } /** - * An index that returns only elements on a given dimension up to a - * specific coordinate. + * An index that returns only elements on a given dimension starting at a specific coordinate, using the given + * stride. * *

For example, given a vector with {@code n} elements on the {@code x} axis, and {@code n > k}, - * {@code to(k)} returns x0, x1, ..., xk + * {@code from(k)} returns xk, xk+1, ..., xn-1 * - * @param end coordinate of the last element of the sequence (exclusive) + * @param start coordinate of the first element of the sequence + * @param stride the stride to use * @return index + * @see #slice(long, long, long) */ - public static Index sliceTo(long end) { - return slice(null, end); + public static Index sliceFrom(long start, long stride) { + return new SliceFrom(start, stride); } /** - * An index that returns only elements on a given dimension starting at a - * specific coordinate, using the given stride. + * An index that returns only elements on a given dimension up to a specific coordinate. * *

For example, given a vector with {@code n} elements on the {@code x} axis, and {@code n > k}, - * {@code from(k)} returns xk, xk+1, ..., xn-1 + * {@code to(k)} returns x0, x1, ..., xk * - * @param start coordinate of the first element of the sequence - * @param stride the stride to use + * @param end coordinate of the last element of the sequence (exclusive) * @return index - * @see #slice(long, long, long) */ - public static Index sliceFrom(long start, long stride) { - return slice(start, null, stride); + public static Index sliceTo(long end) { + return sliceTo(end, 1); } /** - * An index that returns only elements on a given dimension up to a - * specific coordinate, using the given stride. + * An index that returns only elements on a given dimension up to a specific coordinate, using the given stride. * *

For example, given a vector with {@code n} elements on the {@code x} axis, and {@code n > k}, * {@code to(k)} returns x0, x1, ..., xk @@ -244,7 +239,7 @@ public static Index sliceFrom(long start, long stride) { * @see #slice(long, long, long) */ public static Index sliceTo(long end, long stride) { - return slice(null, end, stride); + return new SliceTo(end, stride); } /** @@ -272,16 +267,15 @@ public static Index range(long start, long end) { public static Index flip() { return slice(null, null, -1); } - + /** - * An index that returns elements according to an hyperslab defined by {@code start}, - * {@code stride}, {@code count}, {@code block}. See {@link Hyperslab}. - * + * An index that returns elements according to an hyperslab defined by {@code start}, {@code stride}, {@code count}, + * {@code block}. See {@link Hyperslab}. + * * @param start Starting location for the hyperslab. * @param stride The number of elements to separate each element or block to be selected. * @param count The number of elements or blocks to select along the dimension. * @param block The size of the block selected from the dimension. - * * @return index */ public static Index hyperslab(long start, long stride, long count, long block) { @@ -293,123 +287,87 @@ public static Index hyperslab(long start, long stride, long count, long block) { * * @return index */ - public static Index newAxis(){ + public static Index newAxis() { return NewAxis.INSTANCE; } /** - * An index that expands to fill all available source dimensions. - * Works the same as Python's {@code ...}. - * @see #expand() + * An index that expands to fill all available source dimensions. Works the same as Python's {@code ...}. + * * @return index + * @see #expand() */ - public static Index ellipsis(){ + public static Index ellipsis() { return Ellipsis.INSTANCE; } /** - * An index that expands to fill all available source dimensions. - * Works the same as Python's {@code ...}. + * An index that expands to fill all available source dimensions. Works the same as Python's {@code ...}. * * @return index */ - public static Index expand(){ + public static Index expand() { return ellipsis(); } /** - * An index that returns elements between {@code start} and {@code end}. - * If {@code start} or {@code end} is {@code null}, starts or ends at the beginning or the end, respectively. + * An index that returns elements between {@code start} and {@code end}. If {@code start} or {@code end} is {@code + * null}, starts or ends at the beginning or the end, respectively. *

* Analogous to Python's {@code :} slice syntax. * * @return index */ - public static Index slice(Long start, Long end){ + public static Index slice(long start, long end) { return slice(start, end, 1); } /** - * An index that returns elements between {@code start} and {@code end}. - * If {@code start} or {@code end} is {@code null}, starts or ends at the beginning or the end, respectively. + * An index that returns every {@code stride}-th element between {@code start} and {@code end}. If {@code start} or + * {@code end} is {@code null}, starts or ends at the beginning or the end, respectively. *

* Analogous to Python's {@code :} slice syntax. * * @return index */ - public static Index slice(long start, Long end){ - return slice(start, end, 1); - } - - /** - * An index that returns elements between {@code start} and {@code end}. - * If {@code start} or {@code end} is {@code null}, starts or ends at the beginning or the end, respectively. - *

- * Analogous to Python's {@code :} slice syntax. - * - * @return index - */ - public static Index slice(Long start, long end){ - return slice(start, end, 1); + public static Index slice(long start, long end, long stride) { + return new Slice(start, end, stride); } /** - * An index that returns elements between {@code start} and {@code end}. - * If {@code start} or {@code end} is {@code null}, starts or ends at the beginning or the end, respectively. + * An index that returns elements between {@code start} and {@code end}. If {@code start} or {@code end} is {@code + * null}, starts or ends at the beginning or the end, respectively. *

* Analogous to Python's {@code :} slice syntax. * * @return index */ - public static Index slice(long start, long end){ + public static Index slice(Long start, Long end) { return slice(start, end, 1); } /** - * An index that returns every {@code stride}-th element between {@code start} and {@code end}. - * If {@code start} or {@code end} is {@code null}, starts or ends at the beginning or the end, respectively. + * An index that returns every {@code stride}-th element between {@code start} and {@code end}. If {@code start} or + * {@code end} is {@code null}, starts or ends at the beginning or the end, respectively. *

* Analogous to Python's {@code :} slice syntax. * * @return index */ - public static Index slice(Long start, Long end, long stride){ - return new Slice(start, end, stride); - } - - /** - * An index that returns every {@code stride}-th element between {@code start} and {@code end}. - * If {@code start} or {@code end} is {@code null}, starts or ends at the beginning or the end, respectively. - *

- * Analogous to Python's {@code :} slice syntax. - * - * @return index - */ - public static Index slice(long start, Long end, long stride){ - return new Slice(start, end, stride); - } + public static Index slice(Long start, Long end, long stride) { + if (start == null && end == null) { + if (stride == 1) { + return Indices.all(); + } else { + return Indices.step(stride); + } + } else if (start == null) { + return Indices.sliceTo(end, stride); + } else if (end == null) { + return Indices.sliceFrom(start, stride); + } - /** - * An index that returns every {@code stride}-th element between {@code start} and {@code end}. - * If {@code start} or {@code end} is {@code null}, starts or ends at the beginning or the end, respectively. - *

- * Analogous to Python's {@code :} slice syntax. - * - * @return index - */ - public static Index slice(Long start, long end, long stride){ - return new Slice(start, end, stride); + return slice(start.longValue(), end.longValue(), stride); } - /** - * An index that returns every {@code stride}-th element between {@code start} and {@code end}. - * If {@code start} or {@code end} is {@code null}, starts or ends at the beginning or the end, respectively. - *

- * Analogous to Python's {@code :} slice syntax. - * - * @return index - */ - public static Index slice(long start, long end, long stride){ - return new Slice(start, end, stride); - } } diff --git a/ndarray/src/main/java/org/tensorflow/ndarray/index/Slice.java b/ndarray/src/main/java/org/tensorflow/ndarray/index/Slice.java index 12ce613842a..1be4368261c 100644 --- a/ndarray/src/main/java/org/tensorflow/ndarray/index/Slice.java +++ b/ndarray/src/main/java/org/tensorflow/ndarray/index/Slice.java @@ -21,7 +21,7 @@ final class Slice implements Index { - Slice(Long start, Long end, long stride) { + Slice(long start, long end, long stride) { this.start = start; this.end = end; this.stride = stride; @@ -45,12 +45,12 @@ public long mapCoordinate(long coordinate, Dimension dim) { @Override public long begin() { - return start == null ? 0 : start; + return start; } @Override public long end() { - return end == null ? 0 : end; + return end; } @Override @@ -58,16 +58,6 @@ public long stride() { return stride; } - @Override - public boolean beginMask() { - return start == null; - } - - @Override - public boolean endMask() { - return end == null; - } - @Override public String toString() { return new StringJoiner(", ", Slice.class.getSimpleName() + "(", ")") @@ -78,13 +68,7 @@ public String toString() { } private long start(Dimension dim) { - if (start == null) { - if (stride > 0) { - return 0; - } - - return dim.numElements() - 1; // it's inclusive - } else if (start < 0) { + if (start < 0) { return dim.numElements() + start; } @@ -92,20 +76,14 @@ private long start(Dimension dim) { } private long end(Dimension dim) { - if (end == null) { - if (stride > 0) { - return dim.numElements(); - } else { - return -1; // it's exclusive - } - } else if (end < 0) { + if (end < 0) { return dim.numElements() + end; } else { return end; } } - private final Long start; - private final Long end; + private final long start; + private final long end; private final long stride; } diff --git a/ndarray/src/main/java/org/tensorflow/ndarray/index/SliceFrom.java b/ndarray/src/main/java/org/tensorflow/ndarray/index/SliceFrom.java new file mode 100644 index 00000000000..c968a325cf7 --- /dev/null +++ b/ndarray/src/main/java/org/tensorflow/ndarray/index/SliceFrom.java @@ -0,0 +1,86 @@ +/* + Copyright 2020 The TensorFlow Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + ============================================================================== + */ +package org.tensorflow.ndarray.index; + +import java.util.StringJoiner; +import org.tensorflow.ndarray.impl.dimension.Dimension; + +final class SliceFrom implements Index { + + SliceFrom(long start, long stride) { + this.start = start; + this.stride = stride; + + if (stride == 0) { + throw new IllegalArgumentException("Can not have a stride of 0"); + } + } + + @Override + public long numElements(Dimension dim) { + long length = end(dim) - start(dim); + + return (length / stride) + (length % stride != 0 ? 1 : 0); + } + + @Override + public long mapCoordinate(long coordinate, Dimension dim) { + return start(dim) + stride * coordinate; + } + + @Override + public long begin() { + return start; + } + + @Override + public boolean endMask() { + return true; + } + + @Override + public long stride() { + return stride; + } + + @Override + public String toString() { + return new StringJoiner(", ", SliceFrom.class.getSimpleName() + "(", ")") + .add("start=" + start) + .add("stride=" + stride) + .toString(); + } + + private long start(Dimension dim) { + if (start < 0) { + return dim.numElements() + start; + } + + return start; + } + + private long end(Dimension dim) { + if (stride > 0) { + return dim.numElements(); + } else { + return -1; // it's exclusive + } + } + + private final long start; + private final long stride; +} diff --git a/ndarray/src/main/java/org/tensorflow/ndarray/index/SliceTo.java b/ndarray/src/main/java/org/tensorflow/ndarray/index/SliceTo.java new file mode 100644 index 00000000000..761d1d52a3a --- /dev/null +++ b/ndarray/src/main/java/org/tensorflow/ndarray/index/SliceTo.java @@ -0,0 +1,86 @@ +/* + Copyright 2020 The TensorFlow Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + ============================================================================== + */ +package org.tensorflow.ndarray.index; + +import java.util.StringJoiner; +import org.tensorflow.ndarray.impl.dimension.Dimension; + +final class SliceTo implements Index { + + SliceTo(long end, long stride) { + this.end = end; + this.stride = stride; + + if (stride == 0) { + throw new IllegalArgumentException("Can not have a stride of 0"); + } + } + + @Override + public long numElements(Dimension dim) { + long length = end(dim) - start(dim); + + return (length / stride) + (length % stride != 0 ? 1 : 0); + } + + @Override + public long mapCoordinate(long coordinate, Dimension dim) { + return start(dim) + stride * coordinate; + } + + @Override + public long end() { + return end; + } + + @Override + public boolean beginMask() { + return true; + } + + @Override + public long stride() { + return stride; + } + + @Override + public String toString() { + return new StringJoiner(", ", SliceTo.class.getSimpleName() + "(", ")") + .add("end=" + end) + .add("stride=" + stride) + .toString(); + } + + private long start(Dimension dim) { + if (stride > 0) { + return 0; + } + + return dim.numElements() - 1; // it's inclusive + } + + private long end(Dimension dim) { + if (end < 0) { + return dim.numElements() + end; + } else { + return end; + } + } + + private final long end; + private final long stride; +} diff --git a/ndarray/src/main/java/org/tensorflow/ndarray/index/Step.java b/ndarray/src/main/java/org/tensorflow/ndarray/index/Step.java new file mode 100644 index 00000000000..c9a21c507b6 --- /dev/null +++ b/ndarray/src/main/java/org/tensorflow/ndarray/index/Step.java @@ -0,0 +1,83 @@ +/* + Copyright 2020 The TensorFlow Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + ============================================================================== + */ +package org.tensorflow.ndarray.index; + +import java.util.StringJoiner; +import org.tensorflow.ndarray.impl.dimension.Dimension; + +final class Step implements Index { + + Step(long stride) { + this.stride = stride; + + if (stride == 0) { + throw new IllegalArgumentException("Can not have a stride of 0"); + } + } + + @Override + public long numElements(Dimension dim) { + long length = end(dim) - start(dim); + + return (length / stride) + (length % stride != 0 ? 1 : 0); + } + + @Override + public long mapCoordinate(long coordinate, Dimension dim) { + return start(dim) + stride * coordinate; + } + + @Override + public boolean beginMask() { + return true; + } + + @Override + public boolean endMask() { + return true; + } + + @Override + public long stride() { + return stride; + } + + @Override + public String toString() { + return new StringJoiner(", ", Step.class.getSimpleName() + "(", ")") + .add("stride=" + stride) + .toString(); + } + + private long start(Dimension dim) { + if (stride > 0) { + return 0; + } + + return dim.numElements() - 1; // it's inclusive + } + + private long end(Dimension dim) { + if (stride > 0) { + return dim.numElements(); + } else { + return -1; // it's exclusive + } + } + + private final long stride; +} diff --git a/ndarray/src/test/java/org/tensorflow/ndarray/IndexTest.java b/ndarray/src/test/java/org/tensorflow/ndarray/IndexTest.java index 70c6f21d30c..6f92dab9b99 100644 --- a/ndarray/src/test/java/org/tensorflow/ndarray/IndexTest.java +++ b/ndarray/src/test/java/org/tensorflow/ndarray/IndexTest.java @@ -25,19 +25,19 @@ public class IndexTest { @Test public void testNullConversions(){ - assertTrue(Indices.slice(null, 0).beginMask(), + assertTrue(Indices.slice(null, 0L).beginMask(), "Passed null for slice start but didn't set begin mask"); - assertTrue(Indices.slice(null, 0).beginMask(), + assertTrue(Indices.slice(null, 0L).beginMask(), "Passed null for slice start but didn't set begin mask"); assertTrue(Indices.slice(null, null).beginMask(), "Passed null for slice start but didn't set begin mask"); - assertTrue(Indices.slice(0, null).endMask(), + assertTrue(Indices.slice(0L, null).endMask(), "Passed null for slice end but didn't set end mask"); - assertTrue(Indices.slice(0, null).endMask(), + assertTrue(Indices.slice(0L, null).endMask(), "Passed null for slice end but didn't set end mask"); assertTrue(Indices.slice(null, null).endMask(), @@ -135,7 +135,7 @@ public void testSlice(){ scalar.setInt((int)coords[2]) ); - IntNdArray slice1 = matrix3d.slice(Indices.all(), Indices.slice(null, 3), Indices.all()); + IntNdArray slice1 = matrix3d.slice(Indices.all(), Indices.sliceTo(3), Indices.all()); assertEquals(Shape.of(5, 3, 5), slice1.shape()); assertEquals(0, slice1.getInt(0, 0, 0)); diff --git a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/op/core/IndexingTest.java b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/op/core/IndexingTest.java index a4fcba2caec..6e86573b7cf 100644 --- a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/op/core/IndexingTest.java +++ b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/op/core/IndexingTest.java @@ -35,8 +35,8 @@ public class IndexingTest { Indices.all(), Indices.newAxis(), Indices.ellipsis(), - Indices.slice(null, 4), - Indices.slice(4, null, 2) + Indices.sliceTo( 4), + Indices.sliceFrom(4, 2) }; @Test From fffdabdfd095a7a6e81afca9d066d1ba0c7eacfe Mon Sep 17 00:00:00 2001 From: Ryan Nett Date: Fri, 8 Jan 2021 14:10:16 -0800 Subject: [PATCH 22/24] Change benchmark to fork once by default Signed-off-by: Ryan Nett --- .../java/org/tensorflow/ndarray/benchmark/NdArrayBenchmark.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ndarray/src/test/java/org/tensorflow/ndarray/benchmark/NdArrayBenchmark.java b/ndarray/src/test/java/org/tensorflow/ndarray/benchmark/NdArrayBenchmark.java index 8acfdff7721..fb7022bc830 100644 --- a/ndarray/src/test/java/org/tensorflow/ndarray/benchmark/NdArrayBenchmark.java +++ b/ndarray/src/test/java/org/tensorflow/ndarray/benchmark/NdArrayBenchmark.java @@ -38,7 +38,7 @@ import org.tensorflow.ndarray.NdArrays; import org.tensorflow.ndarray.StdArrays; -@Fork(value = 0, jvmArgs = {"-Xms4G", "-Xmx4G"}) +@Fork(value = 1, jvmArgs = {"-Xms4G", "-Xmx4G"}) @BenchmarkMode(Mode.AverageTime) @Warmup(iterations = 3) @Measurement(iterations = 5) From 3d6c2436f6cb152b9d569f64180e949f07d1120a Mon Sep 17 00:00:00 2001 From: Ryan Nett Date: Sun, 24 Jan 2021 14:49:58 -0800 Subject: [PATCH 23/24] Remove expand Signed-off-by: Ryan Nett --- .../java/org/tensorflow/ndarray/index/Indices.java | 10 ---------- 1 file changed, 10 deletions(-) diff --git a/ndarray/src/main/java/org/tensorflow/ndarray/index/Indices.java b/ndarray/src/main/java/org/tensorflow/ndarray/index/Indices.java index 555d7c1a496..346ab705595 100644 --- a/ndarray/src/main/java/org/tensorflow/ndarray/index/Indices.java +++ b/ndarray/src/main/java/org/tensorflow/ndarray/index/Indices.java @@ -295,21 +295,11 @@ public static Index newAxis() { * An index that expands to fill all available source dimensions. Works the same as Python's {@code ...}. * * @return index - * @see #expand() */ public static Index ellipsis() { return Ellipsis.INSTANCE; } - /** - * An index that expands to fill all available source dimensions. Works the same as Python's {@code ...}. - * - * @return index - */ - public static Index expand() { - return ellipsis(); - } - /** * An index that returns elements between {@code start} and {@code end}. If {@code start} or {@code end} is {@code * null}, starts or ends at the beginning or the end, respectively. From 6e0e86fa7c5f068bb56f143f8fe7da2bfb0d2830 Mon Sep 17 00:00:00 2001 From: Ryan Nett Date: Sun, 24 Jan 2021 17:50:26 -0800 Subject: [PATCH 24/24] remove tensor references Signed-off-by: Ryan Nett --- .../org/tensorflow/ndarray/index/Hyperslab.java | 2 +- .../java/org/tensorflow/ndarray/index/Index.java | 15 ++++++++------- .../org/tensorflow/ndarray/index/Sequence.java | 2 +- .../tensorflow/op/core/StridedSliceHelper.java | 2 +- 4 files changed, 11 insertions(+), 10 deletions(-) diff --git a/ndarray/src/main/java/org/tensorflow/ndarray/index/Hyperslab.java b/ndarray/src/main/java/org/tensorflow/ndarray/index/Hyperslab.java index dc8eb6244dd..55c4e510748 100644 --- a/ndarray/src/main/java/org/tensorflow/ndarray/index/Hyperslab.java +++ b/ndarray/src/main/java/org/tensorflow/ndarray/index/Hyperslab.java @@ -84,7 +84,7 @@ public String toString() { } @Override - public boolean tensorSupport() { + public boolean isStridedSlicingCompliant() { return false; } } diff --git a/ndarray/src/main/java/org/tensorflow/ndarray/index/Index.java b/ndarray/src/main/java/org/tensorflow/ndarray/index/Index.java index 028405ce243..617ca4d474b 100644 --- a/ndarray/src/main/java/org/tensorflow/ndarray/index/Index.java +++ b/ndarray/src/main/java/org/tensorflow/ndarray/index/Index.java @@ -87,42 +87,43 @@ default boolean isEllipsis() { } /** - * Get whether the Index supports Tensor conversion. + * Get whether the Index supports strided slice style indexing (using start, end, stride, and flags, i.e. TensorFlow's). */ - default boolean tensorSupport() { + default boolean isStridedSlicingCompliant() { return true; } /** - * Get the start of the index, for Tensor conversion. + * Get the start of the index, for strided slice style indexing. */ default long begin() { return 0; } /** - * Get the end of the index, for Tensor conversion. + * Get the end of the index, strided slice style indexing. */ default long end() { return 0; } /** - * Get the stride of the index, for Tensor conversion. + * Get the stride of the index, for strided slice style indexing. */ default long stride() { return 1; } /** - * Get whether the Index should start at the beginning of the dimension, for Tensor conversion. + * Get whether the Index should start at the beginning of the dimension, for strided slice style indexing. */ default boolean beginMask() { return false; } /** - * Get whether the Index should end at the beginning of the dimension, for Tensor conversion. + * Get whether the Index should end at the beginning of the dimension, for strided slice style indexing. */ default boolean endMask() { return false; diff --git a/ndarray/src/main/java/org/tensorflow/ndarray/index/Sequence.java b/ndarray/src/main/java/org/tensorflow/ndarray/index/Sequence.java index 6b8951ba2d1..5b93e434e54 100644 --- a/ndarray/src/main/java/org/tensorflow/ndarray/index/Sequence.java +++ b/ndarray/src/main/java/org/tensorflow/ndarray/index/Sequence.java @@ -46,7 +46,7 @@ public String toString() { } @Override - public boolean tensorSupport() { + public boolean isStridedSlicingCompliant() { return false; } } diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/core/StridedSliceHelper.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/core/StridedSliceHelper.java index cbb818597b8..e97934ee312 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/core/StridedSliceHelper.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/core/StridedSliceHelper.java @@ -71,7 +71,7 @@ static StridedSliceArgs mergeIndexes(Index[] indices) { idx = Indices.all(); } - if (!idx.tensorSupport()) { + if (!idx.isStridedSlicingCompliant()) { throw new UnsupportedOperationException("Index " + idx + " is not supported for Tensors"); }