Skip to content

Commit 4305031

Browse files
committed
Split Slice into nullability cases
Signed-off-by: Ryan Nett <[email protected]>
1 parent a6756a2 commit 4305031

File tree

7 files changed

+336
-145
lines changed

7 files changed

+336
-145
lines changed

ndarray/src/main/java/org/tensorflow/ndarray/index/Indices.java

Lines changed: 67 additions & 109 deletions
Original file line numberDiff line numberDiff line change
@@ -34,8 +34,8 @@ public final class Indices {
3434
* single element and therefore is excluded from the computation of the rank.
3535
*
3636
* <p>For example, given a 3D matrix on the axis [x, y, z], if
37-
* {@code matrix.slice(all(), at(0), at(0)}, then the rank of the returned slice is 1 and its
38-
* number of elements is {@code x.numElements()}
37+
* {@code matrix.slice(all(), at(0), at(0)}, then the rank of the returned slice is 1 and its number of elements is
38+
* {@code x.numElements()}
3939
*
4040
* @param coord coordinate of the element on the indexed axis
4141
* @return index
@@ -65,12 +65,12 @@ public static Index at(NdArray<? extends Number> coord) {
6565
* A coordinate that selects a specific element on a given dimension.
6666
*
6767
* <p>When this index is applied to a given dimension, the dimension is resolved as a
68-
* single element and therefore, if {@code keepDim} is false, is excluded from the computation of the rank.
69-
* If {@code} keepDim is true, the dimension is collapsed down to one element.
68+
* single element and therefore, if {@code keepDim} is false, is excluded from the computation of the rank. If {@code}
69+
* keepDim is true, the dimension is collapsed down to one element.
7070
*
7171
* <p>For example, given a 3D matrix on the axis [x, y, z], if
72-
* {@code matrix.slice(all(), at(0), at(0)}, then the rank of the returned slice is 1 and its
73-
* number of elements is {@code x.numElements()}
72+
* {@code matrix.slice(all(), at(0), at(0)}, then the rank of the returned slice is 1 and its number of elements is
73+
* {@code x.numElements()}
7474
*
7575
* @param coord coordinate of the element on the indexed axis
7676
* @param keepDim whether to remove the dimension.
@@ -89,8 +89,8 @@ public static Index at(long coord, boolean keepDim) {
8989
* If {@code} keepDim is true, the dimension is collapsed down to one element instead of being removed.
9090
*
9191
* @param coord scalar indicating the coordinate of the element on the indexed axis
92-
* @return index
9392
* @param keepDim whether to remove the dimension.
93+
* @return index
9494
* @throws IllegalRankException if {@code coord} is not a scalar (rank 0)
9595
*/
9696
public static Index at(NdArray<? extends Number> coord, boolean keepDim) {
@@ -149,29 +149,27 @@ public static Index seq(NdArray<? extends Number> coords) {
149149
}
150150

151151
/**
152-
* An index that returns only elements found at an even position in the
153-
* original dimension.
152+
* An index that returns only elements found at an even position in the original dimension.
154153
*
155154
* <p>For example, given a vector with {@code n} elements on the {@code x} axis, and n is even,
156155
* {@code even()} returns x<sub>0</sub>, x<sub>2</sub>, ..., x<sub>n-2</sub>
157156
*
158157
* @return index
159158
*/
160159
public static Index even() {
161-
return slice(null, null, 2);
160+
return step(2);
162161
}
163162

164163
/**
165-
* An index that returns only elements found at an odd position in the
166-
* original dimension.
164+
* An index that returns only elements found at an odd position in the original dimension.
167165
*
168166
* <p>For example, given a vector with {@code n} elements on the {@code x} axis, and n is even,
169167
* {@code odd()} returns x<sub>1</sub>, x<sub>3</sub>, ..., x<sub>n-1</sub>
170168
*
171169
* @return index
172170
*/
173171
public static Index odd() {
174-
return slice(1, null, 2);
172+
return sliceFrom(1, 2);
175173
}
176174

177175
/**
@@ -180,16 +178,15 @@ public static Index odd() {
180178
* <p>For example, given a vector with {@code n} elements on the {@code x} axis,
181179
* {@code step(k)} returns x<sub>0</sub>, x<sub>k</sub>, x<sub>k*2</sub>, ...
182180
*
183-
* @param stepLength the number of elements between each steps
181+
* @param stride the number of elements between each steps
184182
* @return index
185183
*/
186-
public static Index step(long stepLength) {
187-
return slice(null, null, stepLength);
184+
public static Index step(long stride) {
185+
return new Step(stride);
188186
}
189187

190188
/**
191-
* An index that returns only elements on a given dimension starting at a
192-
* specific coordinate.
189+
* An index that returns only elements on a given dimension starting at a specific coordinate.
193190
*
194191
* <p>For example, given a vector with {@code n} elements on the {@code x} axis, and {@code n > k},
195192
* {@code from(k)} returns x<sub>k</sub>, x<sub>k+1</sub>, ..., x<sub>n-1</sub>
@@ -198,42 +195,40 @@ public static Index step(long stepLength) {
198195
* @return index
199196
*/
200197
public static Index sliceFrom(long start) {
201-
return slice(start, null);
198+
return sliceFrom(start, 1);
202199
}
203200

204201
/**
205-
* An index that returns only elements on a given dimension up to a
206-
* specific coordinate.
202+
* An index that returns only elements on a given dimension starting at a specific coordinate, using the given
203+
* stride.
207204
*
208205
* <p>For example, given a vector with {@code n} elements on the {@code x} axis, and {@code n > k},
209-
* {@code to(k)} returns x<sub>0</sub>, x<sub>1</sub>, ..., x<sub>k</sub>
206+
* {@code from(k)} returns x<sub>k</sub>, x<sub>k+1</sub>, ..., x<sub>n-1</sub>
210207
*
211-
* @param end coordinate of the last element of the sequence (exclusive)
208+
* @param start coordinate of the first element of the sequence
209+
* @param stride the stride to use
212210
* @return index
211+
* @see #slice(long, long, long)
213212
*/
214-
public static Index sliceTo(long end) {
215-
return slice(null, end);
213+
public static Index sliceFrom(long start, long stride) {
214+
return new SliceFrom(start, stride);
216215
}
217216

218217
/**
219-
* An index that returns only elements on a given dimension starting at a
220-
* specific coordinate, using the given stride.
218+
* An index that returns only elements on a given dimension up to a specific coordinate.
221219
*
222220
* <p>For example, given a vector with {@code n} elements on the {@code x} axis, and {@code n > k},
223-
* {@code from(k)} returns x<sub>k</sub>, x<sub>k+1</sub>, ..., x<sub>n-1</sub>
221+
* {@code to(k)} returns x<sub>0</sub>, x<sub>1</sub>, ..., x<sub>k</sub>
224222
*
225-
* @param start coordinate of the first element of the sequence
226-
* @param stride the stride to use
223+
* @param end coordinate of the last element of the sequence (exclusive)
227224
* @return index
228-
* @see #slice(long, long, long)
229225
*/
230-
public static Index sliceFrom(long start, long stride) {
231-
return slice(start, null, stride);
226+
public static Index sliceTo(long end) {
227+
return sliceTo(end, 1);
232228
}
233229

234230
/**
235-
* An index that returns only elements on a given dimension up to a
236-
* specific coordinate, using the given stride.
231+
* An index that returns only elements on a given dimension up to a specific coordinate, using the given stride.
237232
*
238233
* <p>For example, given a vector with {@code n} elements on the {@code x} axis, and {@code n > k},
239234
* {@code to(k)} returns x<sub>0</sub>, x<sub>1</sub>, ..., x<sub>k</sub>
@@ -244,7 +239,7 @@ public static Index sliceFrom(long start, long stride) {
244239
* @see #slice(long, long, long)
245240
*/
246241
public static Index sliceTo(long end, long stride) {
247-
return slice(null, end, stride);
242+
return new SliceTo(end, stride);
248243
}
249244

250245
/**
@@ -272,16 +267,15 @@ public static Index range(long start, long end) {
272267
public static Index flip() {
273268
return slice(null, null, -1);
274269
}
275-
270+
276271
/**
277-
* An index that returns elements according to an hyperslab defined by {@code start},
278-
* {@code stride}, {@code count}, {@code block}. See {@link Hyperslab}.
279-
*
272+
* An index that returns elements according to an hyperslab defined by {@code start}, {@code stride}, {@code count},
273+
* {@code block}. See {@link Hyperslab}.
274+
*
280275
* @param start Starting location for the hyperslab.
281276
* @param stride The number of elements to separate each element or block to be selected.
282277
* @param count The number of elements or blocks to select along the dimension.
283278
* @param block The size of the block selected from the dimension.
284-
*
285279
* @return index
286280
*/
287281
public static Index hyperslab(long start, long stride, long count, long block) {
@@ -293,123 +287,87 @@ public static Index hyperslab(long start, long stride, long count, long block) {
293287
*
294288
* @return index
295289
*/
296-
public static Index newAxis(){
290+
public static Index newAxis() {
297291
return NewAxis.INSTANCE;
298292
}
299293

300294
/**
301-
* An index that expands to fill all available source dimensions.
302-
* Works the same as Python's {@code ...}.
303-
* @see #expand()
295+
* An index that expands to fill all available source dimensions. Works the same as Python's {@code ...}.
296+
*
304297
* @return index
298+
* @see #expand()
305299
*/
306-
public static Index ellipsis(){
300+
public static Index ellipsis() {
307301
return Ellipsis.INSTANCE;
308302
}
309303

310304
/**
311-
* An index that expands to fill all available source dimensions.
312-
* Works the same as Python's {@code ...}.
305+
* An index that expands to fill all available source dimensions. Works the same as Python's {@code ...}.
313306
*
314307
* @return index
315308
*/
316-
public static Index expand(){
309+
public static Index expand() {
317310
return ellipsis();
318311
}
319312

320313
/**
321-
* An index that returns elements between {@code start} and {@code end}.
322-
* If {@code start} or {@code end} is {@code null}, starts or ends at the beginning or the end, respectively.
314+
* An index that returns elements between {@code start} and {@code end}. If {@code start} or {@code end} is {@code
315+
* null}, starts or ends at the beginning or the end, respectively.
323316
* <p>
324317
* Analogous to Python's {@code :} slice syntax.
325318
*
326319
* @return index
327320
*/
328-
public static Index slice(Long start, Long end){
321+
public static Index slice(long start, long end) {
329322
return slice(start, end, 1);
330323
}
331324

332325
/**
333-
* An index that returns elements between {@code start} and {@code end}.
334-
* If {@code start} or {@code end} is {@code null}, starts or ends at the beginning or the end, respectively.
326+
* An index that returns every {@code stride}-th element between {@code start} and {@code end}. If {@code start} or
327+
* {@code end} is {@code null}, starts or ends at the beginning or the end, respectively.
335328
* <p>
336329
* Analogous to Python's {@code :} slice syntax.
337330
*
338331
* @return index
339332
*/
340-
public static Index slice(long start, Long end){
341-
return slice(start, end, 1);
342-
}
343-
344-
/**
345-
* An index that returns elements between {@code start} and {@code end}.
346-
* If {@code start} or {@code end} is {@code null}, starts or ends at the beginning or the end, respectively.
347-
* <p>
348-
* Analogous to Python's {@code :} slice syntax.
349-
*
350-
* @return index
351-
*/
352-
public static Index slice(Long start, long end){
353-
return slice(start, end, 1);
333+
public static Index slice(long start, long end, long stride) {
334+
return new Slice(start, end, stride);
354335
}
355336

356337
/**
357-
* An index that returns elements between {@code start} and {@code end}.
358-
* If {@code start} or {@code end} is {@code null}, starts or ends at the beginning or the end, respectively.
338+
* An index that returns elements between {@code start} and {@code end}. If {@code start} or {@code end} is {@code
339+
* null}, starts or ends at the beginning or the end, respectively.
359340
* <p>
360341
* Analogous to Python's {@code :} slice syntax.
361342
*
362343
* @return index
363344
*/
364-
public static Index slice(long start, long end){
345+
public static Index slice(Long start, Long end) {
365346
return slice(start, end, 1);
366347
}
367348

368349
/**
369-
* An index that returns every {@code stride}-th element between {@code start} and {@code end}.
370-
* If {@code start} or {@code end} is {@code null}, starts or ends at the beginning or the end, respectively.
350+
* An index that returns every {@code stride}-th element between {@code start} and {@code end}. If {@code start} or
351+
* {@code end} is {@code null}, starts or ends at the beginning or the end, respectively.
371352
* <p>
372353
* Analogous to Python's {@code :} slice syntax.
373354
*
374355
* @return index
375356
*/
376-
public static Index slice(Long start, Long end, long stride){
377-
return new Slice(start, end, stride);
378-
}
379-
380-
/**
381-
* An index that returns every {@code stride}-th element between {@code start} and {@code end}.
382-
* If {@code start} or {@code end} is {@code null}, starts or ends at the beginning or the end, respectively.
383-
* <p>
384-
* Analogous to Python's {@code :} slice syntax.
385-
*
386-
* @return index
387-
*/
388-
public static Index slice(long start, Long end, long stride){
389-
return new Slice(start, end, stride);
390-
}
357+
public static Index slice(Long start, Long end, long stride) {
358+
if (start == null && end == null) {
359+
if (stride == 1) {
360+
return Indices.all();
361+
} else {
362+
return Indices.step(stride);
363+
}
364+
} else if (start == null) {
365+
return Indices.sliceTo(end, stride);
366+
} else if (end == null) {
367+
return Indices.sliceFrom(start, stride);
368+
}
391369

392-
/**
393-
* An index that returns every {@code stride}-th element between {@code start} and {@code end}.
394-
* If {@code start} or {@code end} is {@code null}, starts or ends at the beginning or the end, respectively.
395-
* <p>
396-
* Analogous to Python's {@code :} slice syntax.
397-
*
398-
* @return index
399-
*/
400-
public static Index slice(Long start, long end, long stride){
401-
return new Slice(start, end, stride);
370+
return slice(start.longValue(), end.longValue(), stride);
402371
}
403372

404-
/**
405-
* An index that returns every {@code stride}-th element between {@code start} and {@code end}.
406-
* If {@code start} or {@code end} is {@code null}, starts or ends at the beginning or the end, respectively.
407-
* <p>
408-
* Analogous to Python's {@code :} slice syntax.
409-
*
410-
* @return index
411-
*/
412-
public static Index slice(long start, long end, long stride){
413-
return new Slice(start, end, stride);
414-
}
415373
}

0 commit comments

Comments
 (0)