27
27
import java .util .Set ;
28
28
import java .util .logging .Level ;
29
29
import java .util .logging .Logger ;
30
+ import java .util .stream .Collectors ;
30
31
import org .tensorflow .exceptions .TensorFlowException ;
31
32
import org .tensorflow .proto .RunMetadata ;
33
+ import org .tensorflow .types .family .TType ;
32
34
33
35
/**
34
36
* An {@link AutoCloseable} wrapper around a {@link Map} containing {@link Tensor}s.
@@ -115,6 +117,31 @@ public Tensor get(int index) {
115
117
}
116
118
}
117
119
120
+ /**
121
+ * Gets the value from the container at the specified index, casting it to a given tensor type
122
+ *
123
+ * <p>Throws {@link IllegalStateException} if the container has been closed, and {@link
124
+ * IndexOutOfBoundsException} if the index is invalid.
125
+ *
126
+ * @param index The index to lookup.
127
+ * @param type tensor type
128
+ * @return The value at the index.
129
+ */
130
+ public <T extends TType > T get (int index , Class <T > type ) {
131
+ if (!closed ) {
132
+ var tensor = list .get (index );
133
+ try {
134
+ return type .cast (tensor );
135
+ } catch (ClassCastException e ) {
136
+ var tensorName = map .keySet ().stream ().collect (Collectors .toList ()).get (index );
137
+ throw new IllegalArgumentException (
138
+ buildInvalidTensorTypeExceptionMessage (tensor , tensorName , type ));
139
+ }
140
+ } else {
141
+ throw new IllegalStateException ("Result is closed" );
142
+ }
143
+ }
144
+
118
145
/**
119
146
* Gets the value from the container assuming it's not been closed.
120
147
*
@@ -131,6 +158,33 @@ public Optional<Tensor> get(String key) {
131
158
}
132
159
}
133
160
161
+ /**
162
+ * Gets the value from the container, assuming it's not been closed, casting it to a given tensor
163
+ * type.
164
+ *
165
+ * <p>Throws {@link IllegalStateException} if the container has been closed.
166
+ *
167
+ * @param key The key to lookup.
168
+ * @param type tensor type
169
+ * @return Optional.of the value if it exists.
170
+ */
171
+ public <T extends TType > Optional <T > get (String key , Class <T > type ) {
172
+ if (!closed ) {
173
+ return Optional .ofNullable (map .get (key ))
174
+ .map (
175
+ t -> {
176
+ try {
177
+ return type .cast (t );
178
+ } catch (ClassCastException e ) {
179
+ throw new IllegalArgumentException (
180
+ buildInvalidTensorTypeExceptionMessage (t , key , type ));
181
+ }
182
+ });
183
+ } else {
184
+ throw new IllegalStateException ("Result is closed" );
185
+ }
186
+ }
187
+
134
188
/**
135
189
* Metadata about the run.
136
190
*
@@ -196,4 +250,20 @@ public Optional<RunMetadata> getMetadata() {
196
250
private boolean closed ;
197
251
198
252
private static final Logger logger = Logger .getLogger (Result .class .getName ());
253
+
254
+ private String buildInvalidTensorTypeExceptionMessage (
255
+ Tensor tensor , String tensorName , Class <? extends TType > requestedType ) {
256
+ String actualTypeName =
257
+ tensor instanceof TType
258
+ ? ((TType ) tensor ).type ().getSimpleName ()
259
+ : tensor .getClass ().getName ();
260
+ throw new IllegalStateException (
261
+ "Tensor \" "
262
+ + tensorName
263
+ + "\" of type \" "
264
+ + actualTypeName
265
+ + "\" is not compatible with requested type \" "
266
+ + requestedType .getSimpleName ()
267
+ + "\" " );
268
+ }
199
269
}
0 commit comments