Skip to content

Commit 71654ca

Browse files
karllessardCraigacp
authored andcommitted
Fix broadcastMask/Update
Accept partially unknown shaped mask
1 parent 5cf1568 commit 71654ca

File tree

4 files changed

+79
-2
lines changed

4 files changed

+79
-2
lines changed

tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/core/BooleanMask.java

+1-1
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,7 @@ public static <T extends TType> Operand<T> create(
7878
if (maskShape.numDimensions() == 0) {
7979
throw new IllegalArgumentException("Mask cannot be a scalar.");
8080
}
81-
if (maskShape.hasUnknownDimension()) {
81+
if (maskShape.isUnknown()) {
8282
throw new IllegalArgumentException("Mask cannot have unknown number of dimensions");
8383
}
8484

tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/core/BooleanMaskUpdate.java

+1-1
Original file line numberDiff line numberDiff line change
@@ -86,7 +86,7 @@ public static <T extends TType> Operand<T> create(
8686
if (maskShape.numDimensions() == 0) {
8787
throw new IllegalArgumentException("Mask cannot be a scalar.");
8888
}
89-
if (maskShape.hasUnknownDimension()) {
89+
if (maskShape.isUnknown()) {
9090
throw new IllegalArgumentException("Mask cannot have unknown number of dimensions");
9191
}
9292

tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/op/core/BooleanMaskTest.java

+36
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
package org.tensorflow.op.core;
1818

1919
import static org.junit.jupiter.api.Assertions.assertEquals;
20+
import static org.junit.jupiter.api.Assertions.assertThrows;
2021

2122
import org.junit.jupiter.api.Test;
2223
import org.tensorflow.Graph;
@@ -66,4 +67,39 @@ public void testBooleanMask() {
6667
}
6768
}
6869
}
70+
71+
@Test
72+
public void testBooleanMaskWithPartiallyUnknownShape() {
73+
try (Graph g = new Graph();
74+
Session sess = new Session(g)) {
75+
Scope scope = new OpScope(g);
76+
77+
Operand<TInt32> input = Constant.arrayOf(scope, 1, 2, 3, 4);
78+
Placeholder<TBool> inputMask =
79+
Placeholder.create(scope, TBool.class, Placeholder.shape(Shape.of(Shape.UNKNOWN_SIZE)));
80+
81+
Operand<TInt32> output = BooleanMask.create(scope, input, inputMask);
82+
83+
try (TBool mask = TBool.vectorOf(true, false, false, true);
84+
TInt32 result = (TInt32) sess.runner().feed(inputMask, mask).fetch(output).run().get(0)) {
85+
// expected shape from Python tensorflow
86+
assertEquals(Shape.of(2), result.shape());
87+
assertEquals(1, result.getInt(0));
88+
assertEquals(4, result.getInt(1));
89+
}
90+
}
91+
}
92+
93+
@Test
94+
public void testBooleanMaskWithUnknownShape() {
95+
try (Graph g = new Graph()) {
96+
Scope scope = new OpScope(g);
97+
98+
Operand<TInt32> input = Constant.arrayOf(scope, 1, 2, 3, 4);
99+
Placeholder<TBool> inputMask = Placeholder.create(scope, TBool.class);
100+
101+
assertThrows(
102+
IllegalArgumentException.class, () -> BooleanMask.create(scope, input, inputMask));
103+
}
104+
}
69105
}

tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/op/core/BooleanMaskUpdateTest.java

+41
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
package org.tensorflow.op.core;
1818

1919
import static org.junit.jupiter.api.Assertions.assertEquals;
20+
import static org.junit.jupiter.api.Assertions.assertThrows;
2021

2122
import org.junit.jupiter.api.Test;
2223
import org.tensorflow.Graph;
@@ -151,4 +152,44 @@ public void testBooleanMaskUpdateAxis() {
151152
}
152153
}
153154
}
155+
156+
@Test
157+
public void testBooleanMaskUpdateWithPartiallyUnknownShape() {
158+
try (Graph g = new Graph();
159+
Session sess = new Session(g)) {
160+
Scope scope = new OpScope(g);
161+
162+
Operand<TInt32> input = Constant.arrayOf(scope, 1, 2, 3, 4);
163+
Operand<TInt32> updates = Constant.arrayOf(scope, -1, 2);
164+
Placeholder<TBool> inputMask =
165+
Placeholder.create(scope, TBool.class, Placeholder.shape(Shape.of(Shape.UNKNOWN_SIZE)));
166+
167+
Operand<TInt32> output = BooleanMaskUpdate.create(scope, input, inputMask, updates);
168+
169+
try (TBool mask = TBool.vectorOf(false, true, false, true);
170+
TInt32 result = (TInt32) sess.runner().feed(inputMask, mask).fetch(output).run().get(0)) {
171+
// expected shape from Python tensorflow
172+
assertEquals(Shape.of(4), result.shape());
173+
assertEquals(1, result.getInt(0));
174+
assertEquals(-1, result.getInt(1));
175+
assertEquals(3, result.getInt(2));
176+
assertEquals(2, result.getInt(3));
177+
}
178+
}
179+
}
180+
181+
@Test
182+
public void testBooleanMaskUpdateWithUnknownShape() {
183+
try (Graph g = new Graph()) {
184+
Scope scope = new OpScope(g);
185+
186+
Operand<TInt32> input = Constant.arrayOf(scope, 1, 2, 3, 4);
187+
Operand<TInt32> updates = Constant.arrayOf(scope, -1, 2);
188+
Placeholder<TBool> inputMask = Placeholder.create(scope, TBool.class);
189+
190+
assertThrows(
191+
IllegalArgumentException.class,
192+
() -> BooleanMaskUpdate.create(scope, input, inputMask, updates));
193+
}
194+
}
154195
}

0 commit comments

Comments
 (0)