16
16
*/
17
17
package org .tensorflow .ndarray ;
18
18
19
+ import static org .junit .jupiter .api .Assertions .assertEquals ;
19
20
import static org .junit .jupiter .api .Assertions .assertTrue ;
20
21
21
22
import org .junit .jupiter .api .Test ;
@@ -42,4 +43,163 @@ public void testNullConversions(){
42
43
assertTrue (Indices .slice (null , null ).endMask (),
43
44
"Passed null for slice end but didn't set end mask" );
44
45
}
46
+
47
+ @ Test
48
+ public void testNewaxis (){
49
+ IntNdArray matrix3d = NdArrays .ofInts (Shape .of (5 , 4 , 5 ));
50
+
51
+ matrix3d .scalars ().forEachIndexed ((coords , scalar ) ->
52
+ scalar .setInt ((int )coords [2 ])
53
+ );
54
+
55
+ IntNdArray slice1 = matrix3d .slice (Indices .all (), Indices .all (), Indices .all (), Indices .newAxis ());
56
+
57
+ assertEquals (Shape .of (5 , 4 , 5 , 1 ), slice1 .shape ());
58
+ assertEquals (0 , slice1 .getInt (0 , 0 , 0 , 0 ));
59
+ assertEquals (1 , slice1 .getInt (0 , 0 , 1 , 0 ));
60
+ assertEquals (4 , slice1 .getInt (0 , 0 , 4 , 0 ));
61
+ assertEquals (2 , slice1 .getInt (0 , 1 , 2 , 0 ));
62
+
63
+ IntNdArray slice2 = matrix3d .slice (Indices .all (), Indices .all (), Indices .newAxis (), Indices .all ());
64
+
65
+ assertEquals (Shape .of (5 , 4 , 1 , 5 ), slice2 .shape ());
66
+ assertEquals (0 , slice2 .getInt (0 , 0 , 0 , 0 ));
67
+ assertEquals (1 , slice2 .getInt (0 , 0 , 0 , 1 ));
68
+ assertEquals (4 , slice2 .getInt (0 , 0 , 0 , 4 ));
69
+ assertEquals (2 , slice2 .getInt (0 , 1 , 0 , 2 ));
70
+
71
+ IntNdArray slice3 = matrix3d .slice (Indices .all (), Indices .newAxis (), Indices .all (), Indices .all ());
72
+
73
+ assertEquals (Shape .of (5 , 1 , 4 , 5 ), slice3 .shape ());
74
+ assertEquals (0 , slice3 .getInt (0 , 0 , 0 , 0 ));
75
+ assertEquals (1 , slice3 .getInt (0 , 0 , 0 , 1 ));
76
+ assertEquals (4 , slice3 .getInt (0 , 0 , 0 , 4 ));
77
+ assertEquals (2 , slice3 .getInt (0 , 0 , 1 , 2 ));
78
+
79
+ IntNdArray slice4 = matrix3d .slice (Indices .newAxis (), Indices .all (), Indices .all (), Indices .all ());
80
+
81
+ assertEquals (Shape .of (1 , 5 , 4 , 5 ), slice4 .shape ());
82
+ assertEquals (0 , slice4 .getInt (0 , 0 , 0 , 0 ));
83
+ assertEquals (1 , slice4 .getInt (0 , 0 , 0 , 1 ));
84
+ assertEquals (4 , slice4 .getInt (0 , 0 , 0 , 4 ));
85
+ assertEquals (2 , slice4 .getInt (0 , 0 , 1 , 2 ));
86
+
87
+ }
88
+
89
+ @ Test
90
+ public void testEllipsis (){
91
+ IntNdArray matrix3d = NdArrays .ofInts (Shape .of (5 , 4 , 5 ));
92
+
93
+ matrix3d .scalars ().forEachIndexed ((coords , scalar ) ->
94
+ scalar .setInt ((int )coords [2 ])
95
+ );
96
+
97
+ assertEquals (
98
+ matrix3d .slice (Indices .all (), Indices .all (), Indices .at (0 )),
99
+ matrix3d .slice (Indices .ellipsis (), Indices .at (0 ))
100
+ );
101
+
102
+ assertEquals (
103
+ matrix3d .slice (Indices .at (0 ), Indices .all (), Indices .all ()),
104
+ matrix3d .slice (Indices .at (0 ), Indices .ellipsis ())
105
+ );
106
+
107
+ assertEquals (
108
+ matrix3d .slice (Indices .at (0 ), Indices .all (), Indices .at (0 )),
109
+ matrix3d .slice (Indices .at (0 ), Indices .ellipsis (), Indices .at (0 ))
110
+ );
111
+
112
+ // newaxis interacts specially with ellipsis (since it doesn't consume a dimension), test this
113
+
114
+ assertEquals (
115
+ matrix3d .slice (Indices .all (), Indices .all (), Indices .newAxis (), Indices .at (0 )),
116
+ matrix3d .slice (Indices .ellipsis (), Indices .newAxis (), Indices .at (0 ))
117
+ );
118
+
119
+ assertEquals (
120
+ matrix3d .slice (Indices .newAxis (), Indices .all (), Indices .all (), Indices .at (0 )),
121
+ matrix3d .slice (Indices .newAxis (), Indices .ellipsis (), Indices .at (0 ))
122
+ );
123
+
124
+ assertEquals (
125
+ matrix3d .slice (Indices .all (), Indices .all (), Indices .at (0 ), Indices .newAxis ()),
126
+ matrix3d .slice (Indices .ellipsis (), Indices .at (0 ), Indices .newAxis ())
127
+ );
128
+ }
129
+
130
+ @ Test
131
+ public void testSlice (){
132
+ IntNdArray matrix3d = NdArrays .ofInts (Shape .of (5 , 4 , 5 ));
133
+
134
+ matrix3d .scalars ().forEachIndexed ((coords , scalar ) ->
135
+ scalar .setInt ((int )coords [2 ])
136
+ );
137
+
138
+ IntNdArray slice1 = matrix3d .slice (Indices .all (), Indices .slice (null , 3 ), Indices .all ());
139
+
140
+ assertEquals (Shape .of (5 , 3 , 5 ), slice1 .shape ());
141
+ assertEquals (0 , slice1 .getInt (0 , 0 , 0 ));
142
+ assertEquals (1 , slice1 .getInt (0 , 0 , 1 ));
143
+ assertEquals (2 , slice1 .getInt (0 , 1 , 2 ));
144
+
145
+ IntNdArray slice2 = matrix3d .slice (Indices .all (), Indices .all (), Indices .slice (1 , 4 ));
146
+
147
+ assertEquals (Shape .of (5 , 4 , 3 ), slice2 .shape ());
148
+ assertEquals (1 , slice2 .getInt (0 , 0 , 0 ));
149
+ assertEquals (3 , slice2 .getInt (0 , 0 , 2 ));
150
+ assertEquals (2 , slice2 .getInt (0 , 1 , 1 ));
151
+
152
+ assertEquals (slice2 , matrix3d .slice (Indices .all (), Indices .all (), Indices .slice (1 , -1 )));
153
+
154
+ assertEquals (slice2 , matrix3d .slice (Indices .all (), Indices .all (), Indices .slice (-4 , -1 )));
155
+
156
+ assertEquals (Shape .of (5 , 4 , 0 ), matrix3d .slice (Indices .all (), Indices .all (), Indices .slice (1 , 4 , -2 )).shape ());
157
+
158
+ IntNdArray slice3 = matrix3d .slice (Indices .all (), Indices .all (), Indices .slice (4 , 1 , -2 ));
159
+
160
+ assertEquals (Shape .of (5 , 4 , 2 ), slice3 .shape ());
161
+ assertEquals (4 , slice3 .getInt (0 , 0 , 0 ));
162
+ assertEquals (2 , slice3 .getInt (0 , 1 , 1 ));
163
+
164
+ assertEquals (slice3 , matrix3d .slice (Indices .all (), Indices .all (), Indices .slice (-1 , 1 , -2 )));
165
+
166
+ assertEquals (slice3 , matrix3d .slice (Indices .all (), Indices .all (), Indices .slice (-1 , -4 , -2 )));
167
+
168
+ IntNdArray slice4 = matrix3d .slice (Indices .all (), Indices .all (), Indices .slice (null , null , -1 ));
169
+
170
+ assertEquals (Shape .of (5 , 4 , 5 ), slice4 .shape ());
171
+ assertEquals (4 , slice4 .getInt (0 , 0 , 0 ));
172
+ assertEquals (3 , slice4 .getInt (0 , 0 , 1 ));
173
+ assertEquals (2 , slice4 .getInt (0 , 1 , 2 ));
174
+ }
175
+
176
+ @ Test
177
+ public void testAt (){
178
+ IntNdArray matrix3d = NdArrays .ofInts (Shape .of (5 , 4 , 5 ));
179
+
180
+ matrix3d .scalars ().forEachIndexed ((coords , scalar ) ->
181
+ scalar .setInt ((int )coords [2 ])
182
+ );
183
+
184
+ IntNdArray slice1 = matrix3d .slice (Indices .all (), Indices .all (), Indices .at (0 ));
185
+
186
+ assertEquals (Shape .of (5 , 4 ), slice1 .shape ());
187
+ assertEquals (0 , slice1 .getInt (0 , 0 ));
188
+
189
+ IntNdArray slice2 = matrix3d .slice (Indices .all (), Indices .all (), Indices .at (3 ));
190
+
191
+ assertEquals (Shape .of (5 , 4 ), slice2 .shape ());
192
+ assertEquals (3 , slice2 .getInt (0 , 0 ));
193
+
194
+ IntNdArray slice3 = matrix3d .slice (Indices .all (), Indices .all (), Indices .at (-3 ));
195
+
196
+ assertEquals (Shape .of (5 , 4 ), slice3 .shape ());
197
+ assertEquals (2 , slice3 .getInt (0 , 0 ));
198
+
199
+ IntNdArray slice4 = matrix3d .slice (Indices .all (), Indices .all (), Indices .at (-3 , true ));
200
+
201
+ assertEquals (Shape .of (5 , 4 , 1 ), slice4 .shape ());
202
+ assertEquals (2 , slice4 .getInt (0 , 0 , 0 ));
203
+ }
204
+
45
205
}
0 commit comments