Skip to content

Commit 665816c

Browse files
committed
Tests
Signed-off-by: Ryan Nett <[email protected]>
1 parent f27f570 commit 665816c

File tree

1 file changed

+160
-0
lines changed

1 file changed

+160
-0
lines changed

ndarray/src/test/java/org/tensorflow/ndarray/IndexTest.java

Lines changed: 160 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
*/
1717
package org.tensorflow.ndarray;
1818

19+
import static org.junit.jupiter.api.Assertions.assertEquals;
1920
import static org.junit.jupiter.api.Assertions.assertTrue;
2021

2122
import org.junit.jupiter.api.Test;
@@ -42,4 +43,163 @@ public void testNullConversions(){
4243
assertTrue(Indices.slice(null, null).endMask(),
4344
"Passed null for slice end but didn't set end mask");
4445
}
46+
47+
@Test
48+
public void testNewaxis(){
49+
IntNdArray matrix3d = NdArrays.ofInts(Shape.of(5, 4, 5));
50+
51+
matrix3d.scalars().forEachIndexed((coords, scalar) ->
52+
scalar.setInt((int)coords[2])
53+
);
54+
55+
IntNdArray slice1 = matrix3d.slice(Indices.all(), Indices.all(), Indices.all(), Indices.newAxis());
56+
57+
assertEquals(Shape.of(5, 4, 5, 1), slice1.shape());
58+
assertEquals(0, slice1.getInt(0, 0, 0, 0));
59+
assertEquals(1, slice1.getInt(0, 0, 1, 0));
60+
assertEquals(4, slice1.getInt(0, 0, 4, 0));
61+
assertEquals(2, slice1.getInt(0, 1, 2, 0));
62+
63+
IntNdArray slice2 = matrix3d.slice(Indices.all(), Indices.all(), Indices.newAxis(), Indices.all());
64+
65+
assertEquals(Shape.of(5, 4, 1, 5), slice2.shape());
66+
assertEquals(0, slice2.getInt(0, 0, 0, 0));
67+
assertEquals(1, slice2.getInt(0, 0, 0, 1));
68+
assertEquals(4, slice2.getInt(0, 0, 0, 4));
69+
assertEquals(2, slice2.getInt(0, 1, 0, 2));
70+
71+
IntNdArray slice3 = matrix3d.slice(Indices.all(), Indices.newAxis(), Indices.all(), Indices.all());
72+
73+
assertEquals(Shape.of(5, 1, 4, 5), slice3.shape());
74+
assertEquals(0, slice3.getInt(0, 0, 0, 0));
75+
assertEquals(1, slice3.getInt(0, 0, 0, 1));
76+
assertEquals(4, slice3.getInt(0, 0, 0, 4));
77+
assertEquals(2, slice3.getInt(0, 0, 1, 2));
78+
79+
IntNdArray slice4 = matrix3d.slice(Indices.newAxis(), Indices.all(), Indices.all(), Indices.all());
80+
81+
assertEquals(Shape.of(1, 5, 4, 5), slice4.shape());
82+
assertEquals(0, slice4.getInt(0, 0, 0, 0));
83+
assertEquals(1, slice4.getInt(0, 0, 0, 1));
84+
assertEquals(4, slice4.getInt(0, 0, 0, 4));
85+
assertEquals(2, slice4.getInt(0, 0, 1, 2));
86+
87+
}
88+
89+
@Test
90+
public void testEllipsis(){
91+
IntNdArray matrix3d = NdArrays.ofInts(Shape.of(5, 4, 5));
92+
93+
matrix3d.scalars().forEachIndexed((coords, scalar) ->
94+
scalar.setInt((int)coords[2])
95+
);
96+
97+
assertEquals(
98+
matrix3d.slice(Indices.all(), Indices.all(), Indices.at(0)),
99+
matrix3d.slice(Indices.ellipsis(), Indices.at(0))
100+
);
101+
102+
assertEquals(
103+
matrix3d.slice(Indices.at(0), Indices.all(), Indices.all()),
104+
matrix3d.slice(Indices.at(0), Indices.ellipsis())
105+
);
106+
107+
assertEquals(
108+
matrix3d.slice(Indices.at(0), Indices.all(), Indices.at(0)),
109+
matrix3d.slice(Indices.at(0), Indices.ellipsis(), Indices.at(0))
110+
);
111+
112+
// newaxis interacts specially with ellipsis (since it doesn't consume a dimension), test this
113+
114+
assertEquals(
115+
matrix3d.slice(Indices.all(), Indices.all(), Indices.newAxis(), Indices.at(0)),
116+
matrix3d.slice(Indices.ellipsis(), Indices.newAxis(), Indices.at(0))
117+
);
118+
119+
assertEquals(
120+
matrix3d.slice(Indices.newAxis(), Indices.all(), Indices.all(), Indices.at(0)),
121+
matrix3d.slice(Indices.newAxis(), Indices.ellipsis(), Indices.at(0))
122+
);
123+
124+
assertEquals(
125+
matrix3d.slice(Indices.all(), Indices.all(), Indices.at(0), Indices.newAxis()),
126+
matrix3d.slice(Indices.ellipsis(), Indices.at(0), Indices.newAxis())
127+
);
128+
}
129+
130+
@Test
131+
public void testSlice(){
132+
IntNdArray matrix3d = NdArrays.ofInts(Shape.of(5, 4, 5));
133+
134+
matrix3d.scalars().forEachIndexed((coords, scalar) ->
135+
scalar.setInt((int)coords[2])
136+
);
137+
138+
IntNdArray slice1 = matrix3d.slice(Indices.all(), Indices.slice(null, 3), Indices.all());
139+
140+
assertEquals(Shape.of(5, 3, 5), slice1.shape());
141+
assertEquals(0, slice1.getInt(0, 0, 0));
142+
assertEquals(1, slice1.getInt(0, 0, 1));
143+
assertEquals(2, slice1.getInt(0, 1, 2));
144+
145+
IntNdArray slice2 = matrix3d.slice(Indices.all(), Indices.all(), Indices.slice(1, 4));
146+
147+
assertEquals(Shape.of(5, 4, 3), slice2.shape());
148+
assertEquals(1, slice2.getInt(0, 0, 0));
149+
assertEquals(3, slice2.getInt(0, 0, 2));
150+
assertEquals(2, slice2.getInt(0, 1, 1));
151+
152+
assertEquals(slice2, matrix3d.slice(Indices.all(), Indices.all(), Indices.slice(1, -1)));
153+
154+
assertEquals(slice2, matrix3d.slice(Indices.all(), Indices.all(), Indices.slice(-4, -1)));
155+
156+
assertEquals(Shape.of(5, 4, 0), matrix3d.slice(Indices.all(), Indices.all(), Indices.slice(1, 4, -2)).shape());
157+
158+
IntNdArray slice3 = matrix3d.slice(Indices.all(), Indices.all(), Indices.slice(4, 1, -2));
159+
160+
assertEquals(Shape.of(5, 4, 2), slice3.shape());
161+
assertEquals(4, slice3.getInt(0, 0, 0));
162+
assertEquals(2, slice3.getInt(0, 1, 1));
163+
164+
assertEquals(slice3, matrix3d.slice(Indices.all(), Indices.all(), Indices.slice(-1, 1, -2)));
165+
166+
assertEquals(slice3, matrix3d.slice(Indices.all(), Indices.all(), Indices.slice(-1, -4, -2)));
167+
168+
IntNdArray slice4 = matrix3d.slice(Indices.all(), Indices.all(), Indices.slice(null, null, -1));
169+
170+
assertEquals(Shape.of(5, 4, 5), slice4.shape());
171+
assertEquals(4, slice4.getInt(0, 0, 0));
172+
assertEquals(3, slice4.getInt(0, 0, 1));
173+
assertEquals(2, slice4.getInt(0, 1, 2));
174+
}
175+
176+
@Test
177+
public void testAt(){
178+
IntNdArray matrix3d = NdArrays.ofInts(Shape.of(5, 4, 5));
179+
180+
matrix3d.scalars().forEachIndexed((coords, scalar) ->
181+
scalar.setInt((int)coords[2])
182+
);
183+
184+
IntNdArray slice1 = matrix3d.slice(Indices.all(), Indices.all(), Indices.at(0));
185+
186+
assertEquals(Shape.of(5, 4), slice1.shape());
187+
assertEquals(0, slice1.getInt(0, 0));
188+
189+
IntNdArray slice2 = matrix3d.slice(Indices.all(), Indices.all(), Indices.at(3));
190+
191+
assertEquals(Shape.of(5, 4), slice2.shape());
192+
assertEquals(3, slice2.getInt(0, 0));
193+
194+
IntNdArray slice3 = matrix3d.slice(Indices.all(), Indices.all(), Indices.at(-3));
195+
196+
assertEquals(Shape.of(5, 4), slice3.shape());
197+
assertEquals(2, slice3.getInt(0, 0));
198+
199+
IntNdArray slice4 = matrix3d.slice(Indices.all(), Indices.all(), Indices.at(-3, true));
200+
201+
assertEquals(Shape.of(5, 4, 1), slice4.shape());
202+
assertEquals(2, slice4.getInt(0, 0, 0));
203+
}
204+
45205
}

0 commit comments

Comments
 (0)