Skip to content

Commit 2cb6de8

Browse files
committed
Accept partially known shapes in boolean mask/updates
1 parent 5cf1568 commit 2cb6de8

File tree

5 files changed

+149
-2
lines changed

5 files changed

+149
-2
lines changed

tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Result.java

+70
Original file line numberDiff line numberDiff line change
@@ -27,8 +27,10 @@
2727
import java.util.Set;
2828
import java.util.logging.Level;
2929
import java.util.logging.Logger;
30+
import java.util.stream.Collectors;
3031
import org.tensorflow.exceptions.TensorFlowException;
3132
import org.tensorflow.proto.RunMetadata;
33+
import org.tensorflow.types.family.TType;
3234

3335
/**
3436
* An {@link AutoCloseable} wrapper around a {@link Map} containing {@link Tensor}s.
@@ -115,6 +117,31 @@ public Tensor get(int index) {
115117
}
116118
}
117119

120+
/**
121+
* Gets the value from the container at the specified index, casting it to a given tensor type
122+
*
123+
* <p>Throws {@link IllegalStateException} if the container has been closed, and {@link
124+
* IndexOutOfBoundsException} if the index is invalid.
125+
*
126+
* @param index The index to lookup.
127+
* @param type tensor type
128+
* @return The value at the index.
129+
*/
130+
public <T extends TType> T get(int index, Class<T> type) {
131+
if (!closed) {
132+
var tensor = list.get(index);
133+
try {
134+
return type.cast(tensor);
135+
} catch (ClassCastException e) {
136+
var tensorName = map.keySet().stream().collect(Collectors.toList()).get(index);
137+
throw new IllegalArgumentException(
138+
buildInvalidTensorTypeExceptionMessage(tensor, tensorName, type));
139+
}
140+
} else {
141+
throw new IllegalStateException("Result is closed");
142+
}
143+
}
144+
118145
/**
119146
* Gets the value from the container assuming it's not been closed.
120147
*
@@ -131,6 +158,33 @@ public Optional<Tensor> get(String key) {
131158
}
132159
}
133160

161+
/**
162+
* Gets the value from the container, assuming it's not been closed, casting it to a given tensor
163+
* type.
164+
*
165+
* <p>Throws {@link IllegalStateException} if the container has been closed.
166+
*
167+
* @param key The key to lookup.
168+
* @param type tensor type
169+
* @return Optional.of the value if it exists.
170+
*/
171+
public <T extends TType> Optional<T> get(String key, Class<T> type) {
172+
if (!closed) {
173+
return Optional.ofNullable(map.get(key))
174+
.map(
175+
t -> {
176+
try {
177+
return type.cast(t);
178+
} catch (ClassCastException e) {
179+
throw new IllegalArgumentException(
180+
buildInvalidTensorTypeExceptionMessage(t, key, type));
181+
}
182+
});
183+
} else {
184+
throw new IllegalStateException("Result is closed");
185+
}
186+
}
187+
134188
/**
135189
* Metadata about the run.
136190
*
@@ -196,4 +250,20 @@ public Optional<RunMetadata> getMetadata() {
196250
private boolean closed;
197251

198252
private static final Logger logger = Logger.getLogger(Result.class.getName());
253+
254+
private String buildInvalidTensorTypeExceptionMessage(
255+
Tensor tensor, String tensorName, Class<? extends TType> requestedType) {
256+
String actualTypeName =
257+
tensor instanceof TType
258+
? ((TType) tensor).type().getSimpleName()
259+
: tensor.getClass().getName();
260+
throw new IllegalStateException(
261+
"Tensor \""
262+
+ tensorName
263+
+ "\" of type \""
264+
+ actualTypeName
265+
+ "\" is not compatible with requested type \""
266+
+ requestedType.getSimpleName()
267+
+ "\"");
268+
}
199269
}

tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/core/BooleanMask.java

+1-1
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,7 @@ public static <T extends TType> Operand<T> create(
7878
if (maskShape.numDimensions() == 0) {
7979
throw new IllegalArgumentException("Mask cannot be a scalar.");
8080
}
81-
if (maskShape.hasUnknownDimension()) {
81+
if (maskShape.isUnknown()) {
8282
throw new IllegalArgumentException("Mask cannot have unknown number of dimensions");
8383
}
8484

tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/core/BooleanMaskUpdate.java

+1-1
Original file line numberDiff line numberDiff line change
@@ -86,7 +86,7 @@ public static <T extends TType> Operand<T> create(
8686
if (maskShape.numDimensions() == 0) {
8787
throw new IllegalArgumentException("Mask cannot be a scalar.");
8888
}
89-
if (maskShape.hasUnknownDimension()) {
89+
if (maskShape.isUnknown()) {
9090
throw new IllegalArgumentException("Mask cannot have unknown number of dimensions");
9191
}
9292

tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/op/core/BooleanMaskTest.java

+36
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
package org.tensorflow.op.core;
1818

1919
import static org.junit.jupiter.api.Assertions.assertEquals;
20+
import static org.junit.jupiter.api.Assertions.assertThrows;
2021

2122
import org.junit.jupiter.api.Test;
2223
import org.tensorflow.Graph;
@@ -66,4 +67,39 @@ public void testBooleanMask() {
6667
}
6768
}
6869
}
70+
71+
@Test
72+
public void testBooleanMaskWithPartiallyUnknownShape() {
73+
try (Graph g = new Graph();
74+
Session sess = new Session(g)) {
75+
Scope scope = new OpScope(g);
76+
77+
Operand<TInt32> input = Constant.arrayOf(scope, 1, 2, 3, 4);
78+
Placeholder<TBool> inputMask =
79+
Placeholder.create(scope, TBool.class, Placeholder.shape(Shape.of(Shape.UNKNOWN_SIZE)));
80+
81+
Operand<TInt32> output = BooleanMask.create(scope, input, inputMask);
82+
83+
try (TBool mask = TBool.vectorOf(true, false, false, true);
84+
TInt32 result = (TInt32) sess.runner().feed(inputMask, mask).fetch(output).run().get(0)) {
85+
// expected shape from Python tensorflow
86+
assertEquals(Shape.of(2), result.shape());
87+
assertEquals(1, result.getInt(0));
88+
assertEquals(4, result.getInt(1));
89+
}
90+
}
91+
}
92+
93+
@Test
94+
public void testBooleanMaskWithUnknownShape() {
95+
try (Graph g = new Graph()) {
96+
Scope scope = new OpScope(g);
97+
98+
Operand<TInt32> input = Constant.arrayOf(scope, 1, 2, 3, 4);
99+
Placeholder<TBool> inputMask = Placeholder.create(scope, TBool.class);
100+
101+
assertThrows(
102+
IllegalArgumentException.class, () -> BooleanMask.create(scope, input, inputMask));
103+
}
104+
}
69105
}

tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/op/core/BooleanMaskUpdateTest.java

+41
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
package org.tensorflow.op.core;
1818

1919
import static org.junit.jupiter.api.Assertions.assertEquals;
20+
import static org.junit.jupiter.api.Assertions.assertThrows;
2021

2122
import org.junit.jupiter.api.Test;
2223
import org.tensorflow.Graph;
@@ -151,4 +152,44 @@ public void testBooleanMaskUpdateAxis() {
151152
}
152153
}
153154
}
155+
156+
@Test
157+
public void testBooleanMaskUpdateWithPartiallyUnknownShape() {
158+
try (Graph g = new Graph();
159+
Session sess = new Session(g)) {
160+
Scope scope = new OpScope(g);
161+
162+
Operand<TInt32> input = Constant.arrayOf(scope, 1, 2, 3, 4);
163+
Operand<TInt32> updates = Constant.arrayOf(scope, -1, 2);
164+
Placeholder<TBool> inputMask =
165+
Placeholder.create(scope, TBool.class, Placeholder.shape(Shape.of(Shape.UNKNOWN_SIZE)));
166+
167+
Operand<TInt32> output = BooleanMaskUpdate.create(scope, input, inputMask, updates);
168+
169+
try (TBool mask = TBool.vectorOf(false, true, false, true);
170+
TInt32 result = (TInt32) sess.runner().feed(inputMask, mask).fetch(output).run().get(0)) {
171+
// expected shape from Python tensorflow
172+
assertEquals(Shape.of(4), result.shape());
173+
assertEquals(1, result.getInt(0));
174+
assertEquals(-1, result.getInt(1));
175+
assertEquals(3, result.getInt(2));
176+
assertEquals(2, result.getInt(3));
177+
}
178+
}
179+
}
180+
181+
@Test
182+
public void testBooleanMaskUpdateWithUnknownShape() {
183+
try (Graph g = new Graph()) {
184+
Scope scope = new OpScope(g);
185+
186+
Operand<TInt32> input = Constant.arrayOf(scope, 1, 2, 3, 4);
187+
Operand<TInt32> updates = Constant.arrayOf(scope, -1, 2);
188+
Placeholder<TBool> inputMask = Placeholder.create(scope, TBool.class);
189+
190+
assertThrows(
191+
IllegalArgumentException.class,
192+
() -> BooleanMaskUpdate.create(scope, input, inputMask, updates));
193+
}
194+
}
154195
}

0 commit comments

Comments
 (0)