14
14
=======================================================================*/
15
15
package org .tensorflow .framework .metrics .impl ;
16
16
17
+ import static org .junit .jupiter .api .Assertions .assertThrows ;
18
+
19
+ import java .util .List ;
17
20
import org .junit .jupiter .api .Test ;
18
21
import org .tensorflow .Operand ;
19
22
import org .tensorflow .Tensor ;
23
+ import org .tensorflow .TensorScope ;
20
24
import org .tensorflow .framework .utils .TestSession ;
21
25
import org .tensorflow .op .Op ;
22
26
import org .tensorflow .op .Ops ;
26
30
import org .tensorflow .types .TInt64 ;
27
31
import org .tensorflow .types .family .TNumber ;
28
32
29
- import java .util .List ;
30
-
31
- import static org .junit .jupiter .api .Assertions .assertThrows ;
32
-
33
33
public class AssertBroadcastableTest {
34
34
35
35
private final TestSession .Mode tfMode = TestSession .Mode .GRAPH ;
36
36
37
37
int [][][] valueArrayI =
38
- new int [][][] {
39
- {{1 , 2 , 3 , 4 }, {5 , 6 , 7 , 8 }},
40
- {{9 , 10 , 11 , 12 }, {13 , 14 , 15 , 16 }},
41
- {{17 , 18 , 19 , 20 }, {21 , 22 , 23 , 24 }}
38
+ new int [][][]{
39
+ {{1 , 2 , 3 , 4 }, {5 , 6 , 7 , 8 }},
40
+ {{9 , 10 , 11 , 12 }, {13 , 14 , 15 , 16 }},
41
+ {{17 , 18 , 19 , 20 }, {21 , 22 , 23 , 24 }}
42
42
};
43
43
long [][][] valueArrayL =
44
- new long [][][] {
45
- {{1 , 2 , 3 , 4 }, {5 , 6 , 7 , 8 }},
46
- {{9 , 10 , 11 , 12 }, {13 , 14 , 15 , 16 }},
47
- {{17 , 18 , 19 , 20 }, {21 , 22 , 23 , 24 }}
44
+ new long [][][]{
45
+ {{1 , 2 , 3 , 4 }, {5 , 6 , 7 , 8 }},
46
+ {{9 , 10 , 11 , 12 }, {13 , 14 , 15 , 16 }},
47
+ {{17 , 18 , 19 , 20 }, {21 , 22 , 23 , 24 }}
48
48
};
49
49
float [][][] valueArrayF =
50
- new float [][][] {
51
- {{1 , 2 , 3 , 4 }, {5 , 6 , 7 , 8 }},
52
- {{9 , 10 , 11 , 12 }, {13 , 14 , 15 , 16 }},
53
- {{17 , 18 , 19 , 20 }, {21 , 22 , 23 , 24 }}
50
+ new float [][][]{
51
+ {{1 , 2 , 3 , 4 }, {5 , 6 , 7 , 8 }},
52
+ {{9 , 10 , 11 , 12 }, {13 , 14 , 15 , 16 }},
53
+ {{17 , 18 , 19 , 20 }, {21 , 22 , 23 , 24 }}
54
54
};
55
55
double [][][] valueArrayD =
56
- new double [][][] {
57
- {{1 , 2 , 3 , 4 }, {5 , 6 , 7 , 8 }},
58
- {{9 , 10 , 11 , 12 }, {13 , 14 , 15 , 16 }},
59
- {{17 , 18 , 19 , 20 }, {21 , 22 , 23 , 24 }}
56
+ new double [][][]{
57
+ {{1 , 2 , 3 , 4 }, {5 , 6 , 7 , 8 }},
58
+ {{9 , 10 , 11 , 12 }, {13 , 14 , 15 , 16 }},
59
+ {{17 , 18 , 19 , 20 }, {21 , 22 , 23 , 24 }}
60
60
};
61
61
62
62
private <T extends TNumber > void testValid (
@@ -68,10 +68,11 @@ private <T extends TNumber> void testValid(
68
68
Operand <T > weightsPlaceholder = tf .placeholder (type );
69
69
Operand <T > valuesPlaceholder = tf .placeholder (type );
70
70
71
- List <Tensor > tensors =
72
- testSession .getGraphSession ().runner ().fetch (weights ).fetch (values ).run ();
73
- try (Tensor weightsTensor = tensors .get (0 );
74
- Tensor valuesTensor = tensors .get (1 )) {
71
+ try (TensorScope scope = new TensorScope ()) {
72
+ List <Tensor > tensors =
73
+ testSession .getGraphSession ().runner ().fetch (weights ).fetch (values ).run (scope );
74
+ Tensor weightsTensor = tensors .get (0 );
75
+ Tensor valuesTensor = tensors .get (1 );
75
76
Op dynamicOp = MetricsHelper .assertBroadcastable (tf , weightsPlaceholder , valuesPlaceholder );
76
77
77
78
testSession
@@ -80,7 +81,7 @@ private <T extends TNumber> void testValid(
80
81
.feed (weightsPlaceholder , weightsTensor )
81
82
.feed (valuesPlaceholder , valuesTensor )
82
83
.addTarget (dynamicOp )
83
- .run ();
84
+ .run (scope );
84
85
}
85
86
}
86
87
@@ -103,7 +104,7 @@ public void test1x1x1() {
103
104
Ops tf = testSession .getTF ();
104
105
105
106
Operand <TFloat64 > values = tf .constant (valueArrayD );
106
- Operand <TFloat64 > weights = tf .constant (new double [][][] {{{5 }}});
107
+ Operand <TFloat64 > weights = tf .constant (new double [][][]{{{5 }}});
107
108
testValid (testSession , tf , weights , values , TFloat64 .class );
108
109
}
109
110
}
@@ -114,7 +115,7 @@ public void test1x1xN() {
114
115
try (TestSession testSession = TestSession .createTestSession (tfMode )) {
115
116
Ops tf = testSession .getTF ();
116
117
Operand <TInt64 > values = tf .constant (valueArrayL );
117
- Operand <TInt64 > weights = tf .constant (new long [][][] {{{5 , 7 , 11 , 3 }}});
118
+ Operand <TInt64 > weights = tf .constant (new long [][][]{{{5 , 7 , 11 , 3 }}});
118
119
testValid (testSession , tf , weights , values , TInt64 .class );
119
120
}
120
121
}
@@ -125,7 +126,7 @@ public void test1xNx1() {
125
126
try (TestSession testSession = TestSession .createTestSession (tfMode )) {
126
127
Ops tf = testSession .getTF ();
127
128
Operand <TInt32 > values = tf .constant (valueArrayI );
128
- Operand <TInt32 > weights = tf .constant (new int [][][] {{{5 }, {11 }}});
129
+ Operand <TInt32 > weights = tf .constant (new int [][][]{{{5 }, {11 }}});
129
130
testValid (testSession , tf , weights , values , TInt32 .class );
130
131
}
131
132
}
@@ -137,7 +138,7 @@ public void test1xNxN() {
137
138
Ops tf = testSession .getTF ();
138
139
139
140
Operand <TInt32 > values = tf .constant (valueArrayI );
140
- Operand <TInt32 > weights = tf .constant (new int [][][] {{{5 , 7 , 11 , 3 }, {2 , 13 , 7 , 5 }}});
141
+ Operand <TInt32 > weights = tf .constant (new int [][][]{{{5 , 7 , 11 , 3 }, {2 , 13 , 7 , 5 }}});
141
142
testValid (testSession , tf , weights , values , TInt32 .class );
142
143
}
143
144
}
@@ -149,7 +150,7 @@ public void testNx1x1() {
149
150
Ops tf = testSession .getTF ();
150
151
151
152
Operand <TInt32 > values = tf .constant (valueArrayI );
152
- Operand <TInt32 > weights = tf .constant (new int [][][] {{{5 }}, {{7 }}, {{11 }}});
153
+ Operand <TInt32 > weights = tf .constant (new int [][][]{{{5 }}, {{7 }}, {{11 }}});
153
154
testValid (testSession , tf , weights , values , TInt32 .class );
154
155
}
155
156
}
@@ -162,7 +163,7 @@ public void testNx1xN() {
162
163
163
164
Operand <TInt32 > values = tf .constant (valueArrayI );
164
165
Operand <TInt32 > weights =
165
- tf .constant (new int [][][] {{{5 , 7 , 11 , 3 }}, {{2 , 12 , 7 , 5 }}, {{2 , 17 , 11 , 3 }}});
166
+ tf .constant (new int [][][]{{{5 , 7 , 11 , 3 }}, {{2 , 12 , 7 , 5 }}, {{2 , 17 , 11 , 3 }}});
166
167
testValid (testSession , tf , weights , values , TInt32 .class );
167
168
}
168
169
}
@@ -176,10 +177,10 @@ public void testNxNxN() {
176
177
Operand <TInt32 > values = tf .constant (valueArrayI );
177
178
Operand <TInt32 > weights =
178
179
tf .constant (
179
- new int [][][] {
180
- {{5 , 7 , 11 , 3 }, {2 , 12 , 7 , 5 }},
181
- {{2 , 17 , 11 , 3 }, {2 , 17 , 11 , 3 }},
182
- {{5 , 7 , 11 , 3 }, {2 , 12 , 7 , 5 }}
180
+ new int [][][]{
181
+ {{5 , 7 , 11 , 3 }, {2 , 12 , 7 , 5 }},
182
+ {{2 , 17 , 11 , 3 }, {2 , 17 , 11 , 3 }},
183
+ {{5 , 7 , 11 , 3 }, {2 , 12 , 7 , 5 }}
183
184
});
184
185
testValid (testSession , tf , weights , values , TInt32 .class );
185
186
}
@@ -199,7 +200,7 @@ public void testInvalid1x1() {
199
200
try (TestSession testSession = TestSession .createTestSession (tfMode )) {
200
201
Ops tf = testSession .getTF ();
201
202
Operand <TInt32 > values = tf .constant (valueArrayI );
202
- Operand <TInt32 > weights = tf .constant (new int [][] {{5 }});
203
+ Operand <TInt32 > weights = tf .constant (new int [][]{{5 }});
203
204
testValid (testSession , tf , weights , values , TInt32 .class );
204
205
}
205
206
});
@@ -213,7 +214,7 @@ public void testInvalidPrefixMatch() {
213
214
try (TestSession testSession = TestSession .createTestSession (tfMode )) {
214
215
Ops tf = testSession .getTF ();
215
216
Operand <TInt32 > values = tf .constant (valueArrayI );
216
- Operand <TInt32 > weights = tf .constant (new int [][] {{5 , 7 }, {11 , 3 }, {2 , 12 }});
217
+ Operand <TInt32 > weights = tf .constant (new int [][]{{5 , 7 }, {11 , 3 }, {2 , 12 }});
217
218
testValid (testSession , tf , weights , values , TInt32 .class );
218
219
}
219
220
});
@@ -227,7 +228,7 @@ public void testInvalidSuffixMatch() {
227
228
try (TestSession testSession = TestSession .createTestSession (tfMode )) {
228
229
Ops tf = testSession .getTF ();
229
230
Operand <TInt32 > values = tf .constant (valueArrayI );
230
- Operand <TInt32 > weights = tf .constant (new int [][] {{5 , 7 , 11 , 3 }, {2 , 12 , 7 , 5 }});
231
+ Operand <TInt32 > weights = tf .constant (new int [][]{{5 , 7 , 11 , 3 }, {2 , 12 , 7 , 5 }});
231
232
testValid (testSession , tf , weights , values , TInt32 .class );
232
233
}
233
234
});
@@ -241,7 +242,7 @@ public void testInvalidOnesExtraDim() {
241
242
try (TestSession testSession = TestSession .createTestSession (tfMode )) {
242
243
Ops tf = testSession .getTF ();
243
244
Operand <TInt32 > values = tf .constant (valueArrayI );
244
- Operand <TInt32 > weights = tf .constant (new int [][][][] {{{{5 }}}});
245
+ Operand <TInt32 > weights = tf .constant (new int [][][][]{{{{5 }}}});
245
246
testValid (testSession , tf , weights , values , TInt32 .class );
246
247
}
247
248
});
@@ -258,10 +259,10 @@ public void testInvalidPrefixMatchExtraDim() {
258
259
259
260
Operand <TInt32 > weights =
260
261
tf .constant (
261
- new int [][][][] {
262
- {{{5 }, {7 }, {11 }, {3 }}, {{2 }, {12 }, {7 }, {5 }}},
263
- {{{2 }, {17 }, {11 }, {3 }}, {{2 }, {17 }, {11 }, {3 }}},
264
- {{{5 }, {7 }, {11 }, {3 }}, {{2 }, {12 }, {7 }, {5 }}}
262
+ new int [][][][]{
263
+ {{{5 }, {7 }, {11 }, {3 }}, {{2 }, {12 }, {7 }, {5 }}},
264
+ {{{2 }, {17 }, {11 }, {3 }}, {{2 }, {17 }, {11 }, {3 }}},
265
+ {{{5 }, {7 }, {11 }, {3 }}, {{2 }, {12 }, {7 }, {5 }}}
265
266
});
266
267
testValid (testSession , tf , weights , values , TInt32 .class );
267
268
}
@@ -278,12 +279,12 @@ public void testInvalidSuffixMatchExtraDim() {
278
279
Operand <TInt32 > values = tf .constant (valueArrayI );
279
280
Operand <TInt32 > weights =
280
281
tf .constant (
281
- new int [][][][] {
282
- {
283
- {{5 , 7 , 11 , 3 }, {2 , 12 , 7 , 5 }},
284
- {{2 , 17 , 11 , 3 }, {2 , 17 , 11 , 3 }},
285
- {{5 , 7 , 11 , 3 }, {2 , 12 , 7 , 5 }}
286
- }
282
+ new int [][][][]{
283
+ {
284
+ {{5 , 7 , 11 , 3 }, {2 , 12 , 7 , 5 }},
285
+ {{2 , 17 , 11 , 3 }, {2 , 17 , 11 , 3 }},
286
+ {{5 , 7 , 11 , 3 }, {2 , 12 , 7 , 5 }}
287
+ }
287
288
});
288
289
testValid (testSession , tf , weights , values , TInt32 .class );
289
290
}
0 commit comments