Skip to content

Commit a66ca55

Browse files
committed
More framework fixes
Signed-off-by: Ryan Nett <[email protected]>
1 parent f4218af commit a66ca55

File tree

2 files changed

+114
-112
lines changed

2 files changed

+114
-112
lines changed

tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/impl/AssertBroadcastableTest.java

Lines changed: 50 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -14,9 +14,13 @@
1414
=======================================================================*/
1515
package org.tensorflow.framework.metrics.impl;
1616

17+
import static org.junit.jupiter.api.Assertions.assertThrows;
18+
19+
import java.util.List;
1720
import org.junit.jupiter.api.Test;
1821
import org.tensorflow.Operand;
1922
import org.tensorflow.Tensor;
23+
import org.tensorflow.TensorScope;
2024
import org.tensorflow.framework.utils.TestSession;
2125
import org.tensorflow.op.Op;
2226
import org.tensorflow.op.Ops;
@@ -26,37 +30,33 @@
2630
import org.tensorflow.types.TInt64;
2731
import org.tensorflow.types.family.TNumber;
2832

29-
import java.util.List;
30-
31-
import static org.junit.jupiter.api.Assertions.assertThrows;
32-
3333
public class AssertBroadcastableTest {
3434

3535
private final TestSession.Mode tfMode = TestSession.Mode.GRAPH;
3636

3737
int[][][] valueArrayI =
38-
new int[][][] {
39-
{{1, 2, 3, 4}, {5, 6, 7, 8}},
40-
{{9, 10, 11, 12}, {13, 14, 15, 16}},
41-
{{17, 18, 19, 20}, {21, 22, 23, 24}}
38+
new int[][][]{
39+
{{1, 2, 3, 4}, {5, 6, 7, 8}},
40+
{{9, 10, 11, 12}, {13, 14, 15, 16}},
41+
{{17, 18, 19, 20}, {21, 22, 23, 24}}
4242
};
4343
long[][][] valueArrayL =
44-
new long[][][] {
45-
{{1, 2, 3, 4}, {5, 6, 7, 8}},
46-
{{9, 10, 11, 12}, {13, 14, 15, 16}},
47-
{{17, 18, 19, 20}, {21, 22, 23, 24}}
44+
new long[][][]{
45+
{{1, 2, 3, 4}, {5, 6, 7, 8}},
46+
{{9, 10, 11, 12}, {13, 14, 15, 16}},
47+
{{17, 18, 19, 20}, {21, 22, 23, 24}}
4848
};
4949
float[][][] valueArrayF =
50-
new float[][][] {
51-
{{1, 2, 3, 4}, {5, 6, 7, 8}},
52-
{{9, 10, 11, 12}, {13, 14, 15, 16}},
53-
{{17, 18, 19, 20}, {21, 22, 23, 24}}
50+
new float[][][]{
51+
{{1, 2, 3, 4}, {5, 6, 7, 8}},
52+
{{9, 10, 11, 12}, {13, 14, 15, 16}},
53+
{{17, 18, 19, 20}, {21, 22, 23, 24}}
5454
};
5555
double[][][] valueArrayD =
56-
new double[][][] {
57-
{{1, 2, 3, 4}, {5, 6, 7, 8}},
58-
{{9, 10, 11, 12}, {13, 14, 15, 16}},
59-
{{17, 18, 19, 20}, {21, 22, 23, 24}}
56+
new double[][][]{
57+
{{1, 2, 3, 4}, {5, 6, 7, 8}},
58+
{{9, 10, 11, 12}, {13, 14, 15, 16}},
59+
{{17, 18, 19, 20}, {21, 22, 23, 24}}
6060
};
6161

6262
private <T extends TNumber> void testValid(
@@ -68,10 +68,11 @@ private <T extends TNumber> void testValid(
6868
Operand<T> weightsPlaceholder = tf.placeholder(type);
6969
Operand<T> valuesPlaceholder = tf.placeholder(type);
7070

71-
List<Tensor> tensors =
72-
testSession.getGraphSession().runner().fetch(weights).fetch(values).run();
73-
try (Tensor weightsTensor = tensors.get(0);
74-
Tensor valuesTensor = tensors.get(1)) {
71+
try (TensorScope scope = new TensorScope()) {
72+
List<Tensor> tensors =
73+
testSession.getGraphSession().runner().fetch(weights).fetch(values).run(scope);
74+
Tensor weightsTensor = tensors.get(0);
75+
Tensor valuesTensor = tensors.get(1);
7576
Op dynamicOp = MetricsHelper.assertBroadcastable(tf, weightsPlaceholder, valuesPlaceholder);
7677

7778
testSession
@@ -80,7 +81,7 @@ private <T extends TNumber> void testValid(
8081
.feed(weightsPlaceholder, weightsTensor)
8182
.feed(valuesPlaceholder, valuesTensor)
8283
.addTarget(dynamicOp)
83-
.run();
84+
.run(scope);
8485
}
8586
}
8687

@@ -103,7 +104,7 @@ public void test1x1x1() {
103104
Ops tf = testSession.getTF();
104105

105106
Operand<TFloat64> values = tf.constant(valueArrayD);
106-
Operand<TFloat64> weights = tf.constant(new double[][][] {{{5}}});
107+
Operand<TFloat64> weights = tf.constant(new double[][][]{{{5}}});
107108
testValid(testSession, tf, weights, values, TFloat64.class);
108109
}
109110
}
@@ -114,7 +115,7 @@ public void test1x1xN() {
114115
try (TestSession testSession = TestSession.createTestSession(tfMode)) {
115116
Ops tf = testSession.getTF();
116117
Operand<TInt64> values = tf.constant(valueArrayL);
117-
Operand<TInt64> weights = tf.constant(new long[][][] {{{5, 7, 11, 3}}});
118+
Operand<TInt64> weights = tf.constant(new long[][][]{{{5, 7, 11, 3}}});
118119
testValid(testSession, tf, weights, values, TInt64.class);
119120
}
120121
}
@@ -125,7 +126,7 @@ public void test1xNx1() {
125126
try (TestSession testSession = TestSession.createTestSession(tfMode)) {
126127
Ops tf = testSession.getTF();
127128
Operand<TInt32> values = tf.constant(valueArrayI);
128-
Operand<TInt32> weights = tf.constant(new int[][][] {{{5}, {11}}});
129+
Operand<TInt32> weights = tf.constant(new int[][][]{{{5}, {11}}});
129130
testValid(testSession, tf, weights, values, TInt32.class);
130131
}
131132
}
@@ -137,7 +138,7 @@ public void test1xNxN() {
137138
Ops tf = testSession.getTF();
138139

139140
Operand<TInt32> values = tf.constant(valueArrayI);
140-
Operand<TInt32> weights = tf.constant(new int[][][] {{{5, 7, 11, 3}, {2, 13, 7, 5}}});
141+
Operand<TInt32> weights = tf.constant(new int[][][]{{{5, 7, 11, 3}, {2, 13, 7, 5}}});
141142
testValid(testSession, tf, weights, values, TInt32.class);
142143
}
143144
}
@@ -149,7 +150,7 @@ public void testNx1x1() {
149150
Ops tf = testSession.getTF();
150151

151152
Operand<TInt32> values = tf.constant(valueArrayI);
152-
Operand<TInt32> weights = tf.constant(new int[][][] {{{5}}, {{7}}, {{11}}});
153+
Operand<TInt32> weights = tf.constant(new int[][][]{{{5}}, {{7}}, {{11}}});
153154
testValid(testSession, tf, weights, values, TInt32.class);
154155
}
155156
}
@@ -162,7 +163,7 @@ public void testNx1xN() {
162163

163164
Operand<TInt32> values = tf.constant(valueArrayI);
164165
Operand<TInt32> weights =
165-
tf.constant(new int[][][] {{{5, 7, 11, 3}}, {{2, 12, 7, 5}}, {{2, 17, 11, 3}}});
166+
tf.constant(new int[][][]{{{5, 7, 11, 3}}, {{2, 12, 7, 5}}, {{2, 17, 11, 3}}});
166167
testValid(testSession, tf, weights, values, TInt32.class);
167168
}
168169
}
@@ -176,10 +177,10 @@ public void testNxNxN() {
176177
Operand<TInt32> values = tf.constant(valueArrayI);
177178
Operand<TInt32> weights =
178179
tf.constant(
179-
new int[][][] {
180-
{{5, 7, 11, 3}, {2, 12, 7, 5}},
181-
{{2, 17, 11, 3}, {2, 17, 11, 3}},
182-
{{5, 7, 11, 3}, {2, 12, 7, 5}}
180+
new int[][][]{
181+
{{5, 7, 11, 3}, {2, 12, 7, 5}},
182+
{{2, 17, 11, 3}, {2, 17, 11, 3}},
183+
{{5, 7, 11, 3}, {2, 12, 7, 5}}
183184
});
184185
testValid(testSession, tf, weights, values, TInt32.class);
185186
}
@@ -199,7 +200,7 @@ public void testInvalid1x1() {
199200
try (TestSession testSession = TestSession.createTestSession(tfMode)) {
200201
Ops tf = testSession.getTF();
201202
Operand<TInt32> values = tf.constant(valueArrayI);
202-
Operand<TInt32> weights = tf.constant(new int[][] {{5}});
203+
Operand<TInt32> weights = tf.constant(new int[][]{{5}});
203204
testValid(testSession, tf, weights, values, TInt32.class);
204205
}
205206
});
@@ -213,7 +214,7 @@ public void testInvalidPrefixMatch() {
213214
try (TestSession testSession = TestSession.createTestSession(tfMode)) {
214215
Ops tf = testSession.getTF();
215216
Operand<TInt32> values = tf.constant(valueArrayI);
216-
Operand<TInt32> weights = tf.constant(new int[][] {{5, 7}, {11, 3}, {2, 12}});
217+
Operand<TInt32> weights = tf.constant(new int[][]{{5, 7}, {11, 3}, {2, 12}});
217218
testValid(testSession, tf, weights, values, TInt32.class);
218219
}
219220
});
@@ -227,7 +228,7 @@ public void testInvalidSuffixMatch() {
227228
try (TestSession testSession = TestSession.createTestSession(tfMode)) {
228229
Ops tf = testSession.getTF();
229230
Operand<TInt32> values = tf.constant(valueArrayI);
230-
Operand<TInt32> weights = tf.constant(new int[][] {{5, 7, 11, 3}, {2, 12, 7, 5}});
231+
Operand<TInt32> weights = tf.constant(new int[][]{{5, 7, 11, 3}, {2, 12, 7, 5}});
231232
testValid(testSession, tf, weights, values, TInt32.class);
232233
}
233234
});
@@ -241,7 +242,7 @@ public void testInvalidOnesExtraDim() {
241242
try (TestSession testSession = TestSession.createTestSession(tfMode)) {
242243
Ops tf = testSession.getTF();
243244
Operand<TInt32> values = tf.constant(valueArrayI);
244-
Operand<TInt32> weights = tf.constant(new int[][][][] {{{{5}}}});
245+
Operand<TInt32> weights = tf.constant(new int[][][][]{{{{5}}}});
245246
testValid(testSession, tf, weights, values, TInt32.class);
246247
}
247248
});
@@ -258,10 +259,10 @@ public void testInvalidPrefixMatchExtraDim() {
258259

259260
Operand<TInt32> weights =
260261
tf.constant(
261-
new int[][][][] {
262-
{{{5}, {7}, {11}, {3}}, {{2}, {12}, {7}, {5}}},
263-
{{{2}, {17}, {11}, {3}}, {{2}, {17}, {11}, {3}}},
264-
{{{5}, {7}, {11}, {3}}, {{2}, {12}, {7}, {5}}}
262+
new int[][][][]{
263+
{{{5}, {7}, {11}, {3}}, {{2}, {12}, {7}, {5}}},
264+
{{{2}, {17}, {11}, {3}}, {{2}, {17}, {11}, {3}}},
265+
{{{5}, {7}, {11}, {3}}, {{2}, {12}, {7}, {5}}}
265266
});
266267
testValid(testSession, tf, weights, values, TInt32.class);
267268
}
@@ -278,12 +279,12 @@ public void testInvalidSuffixMatchExtraDim() {
278279
Operand<TInt32> values = tf.constant(valueArrayI);
279280
Operand<TInt32> weights =
280281
tf.constant(
281-
new int[][][][] {
282-
{
283-
{{5, 7, 11, 3}, {2, 12, 7, 5}},
284-
{{2, 17, 11, 3}, {2, 17, 11, 3}},
285-
{{5, 7, 11, 3}, {2, 12, 7, 5}}
286-
}
282+
new int[][][][]{
283+
{
284+
{{5, 7, 11, 3}, {2, 12, 7, 5}},
285+
{{2, 17, 11, 3}, {2, 17, 11, 3}},
286+
{{5, 7, 11, 3}, {2, 12, 7, 5}}
287+
}
287288
});
288289
testValid(testSession, tf, weights, values, TInt32.class);
289290
}

0 commit comments

Comments
 (0)