@@ -52,10 +52,10 @@ public static Shape scalar() {
52
52
/**
53
53
* Create a Shape representing a scalar or an N-dimensional value.
54
54
*
55
- * <p>Creates a Shape representing a scalar or an N-dimensional value (N being at least 1),
56
- * with the provided size for each dimension. A -1 indicates that the size of the corresponding
57
- * dimension is unknown. If no sizes are provided, a Shape representing a scalar is created.
58
- * For example:
55
+ * <p>Creates a Shape representing a scalar or an N-dimensional value (N being at least 1), with
56
+ * the provided size for each dimension. A -1 indicates that the size of the corresponding
57
+ * dimension is unknown. If no sizes are provided, a Shape representing a scalar is created. For
58
+ * example:
59
59
*
60
60
* <pre>{@code
61
61
* // A 2-element vector.
@@ -88,11 +88,11 @@ public static Shape of(long... dimensionSizes) {
88
88
/**
89
89
* Returns the total number of elements a Tensor with this Shape would have.
90
90
*
91
- * <p>If {@link Shape#isUnknown()} is true or {@link Shape#hasUnknownDimension()} is true,
92
- * {@link Shape#UNKNOWN_SIZE} is returned.
91
+ * <p>If {@link Shape#isUnknown()} is true or {@link Shape#hasUnknownDimension()} is true, {@link
92
+ * Shape#UNKNOWN_SIZE} is returned.
93
93
*
94
94
* @return The total number of elements a Tensor with this shape would have if it can be
95
- * calculated, else {@link Shape#UNKNOWN_SIZE}.
95
+ * calculated, else {@link Shape#UNKNOWN_SIZE}.
96
96
*/
97
97
public long size () {
98
98
if (size == null ) {
@@ -108,12 +108,11 @@ public long size() {
108
108
* an unknown size, {@link Shape#UNKNOWN_SIZE} is returned.
109
109
*
110
110
* @param i the index of the dimension to get the size for. If this Shape has a known number of
111
- * dimensions, it must be < {@link Shape#numDimensions()}. The index may be negative,
112
- * in which case the position is counted from the end of the shape. E.g.:
113
- * {@code size(-1)} returns the size of the last dimension, {@code size(-2)} the size of
114
- * the second to last dimension etc.
111
+ * dimensions, it must be < {@link Shape#numDimensions()}. The index may be negative, in which
112
+ * case the position is counted from the end of the shape. E.g.: {@code size(-1)} returns the
113
+ * size of the last dimension, {@code size(-2)} the size of the second to last dimension etc.
115
114
* @return The size of the dimension with the given index if known, {@link Shape#UNKNOWN_SIZE}
116
- * otherwise.
115
+ * otherwise.
117
116
*/
118
117
public long size (int i ) {
119
118
if (dimensionSizes == null ) {
@@ -167,8 +166,8 @@ public boolean isUnknown() {
167
166
}
168
167
169
168
/**
170
- * Returns a defensive copy of the this Shape's axes. Changes to the returned array to not
171
- * change this Shape's state. Returns null if {@link Shape#isUnknown()} is true.
169
+ * Returns a defensive copy of the this Shape's axes. Changes to the returned array to not change
170
+ * this Shape's state. Returns null if {@link Shape#isUnknown()} is true.
172
171
*/
173
172
public long [] asArray () {
174
173
if (this .dimensionSizes == null ) {
@@ -186,15 +185,16 @@ public int hashCode() {
186
185
/**
187
186
* Equals implementation for Shapes. Two Shapes are considered equal iff:
188
187
*
188
+ * <p>
189
189
* <ul>
190
- * <li>the number of dimensions is defined and equal for both
191
- * <li>the size of each dimension is defined and equal for both
190
+ * <li>the number of dimensions is defined and equal for both
191
+ * <li>the size of each dimension is defined and equal for both
192
192
* </ul>
193
193
*
194
194
* <p>If either Shape has unknown dimensions (even if they are the same in both) or if either
195
- * shape has an unknown number of dimensions (even if both return {@code true} for
196
- * {@link Shape#isUnknown()}), they are not considered equal! However, a shape will always
197
- * equal itself, even if it is unknown or contains unknown dimensions.
195
+ * shape has an unknown number of dimensions (even if both return {@code true} for {@link
196
+ * Shape#isUnknown()}), they are not considered equal! However, a shape will always equal itself,
197
+ * even if it is unknown or contains unknown dimensions.
198
198
*/
199
199
@ Override
200
200
public boolean equals (Object obj ) {
@@ -233,17 +233,17 @@ public Shape head() {
233
233
}
234
234
235
235
/**
236
- * Returns an n-dimensional Shape with the dimensions matching the first n dimensions
237
- * of this shape
236
+ * Returns an n-dimensional Shape with the dimensions matching the first n dimensions of this
237
+ * shape
238
238
*
239
- * @param n the number of leading dimensions to get, must be < = than {@link Shape#numDimensions()}
240
- * @return an n-dimensional Shape with the first n dimensions matching the first n dimensions
241
- * of this Shape
239
+ * @param n the number of leading dimensions to get, must be < = than {@link Shape#numDimensions()}
240
+ * @return an n-dimensional Shape with the first n dimensions matching the first n dimensions of
241
+ * this Shape
242
242
*/
243
243
public Shape take (int n ) {
244
244
if (n > numDimensions ()) {
245
- throw new ArrayIndexOutOfBoundsException ("Cannot take " + n +
246
- " dimensions, shape has only " + numDimensions () + "." );
245
+ throw new ArrayIndexOutOfBoundsException (
246
+ "Cannot take " + n + " dimensions, shape has only " + numDimensions () + "." );
247
247
}
248
248
long [] newDimensions = new long [n ];
249
249
System .arraycopy (dimensionSizes , 0 , newDimensions , 0 , n );
@@ -257,18 +257,18 @@ public Shape tail() {
257
257
}
258
258
259
259
/**
260
- * Returns an n-dimensional Shape with the dimensions matching the last n dimensions
261
- * of this Shape.
260
+ * Returns an n-dimensional Shape with the dimensions matching the last n dimensions of this
261
+ * Shape.
262
262
*
263
- * @param n the number of trailing dimensions to get, must be < = than
264
- * {@link Shape#numDimensions()}
263
+ * @param n the number of trailing dimensions to get, must be < = than {@link
264
+ * Shape#numDimensions()}
265
265
* @return an n-dimensional shape with the dimensions matching the last n dimensions of this
266
- * Shape, never null
266
+ * Shape, never null
267
267
*/
268
268
public Shape takeLast (int n ) {
269
269
if (n > numDimensions ()) {
270
- throw new ArrayIndexOutOfBoundsException ("Cannot take last " + n +
271
- " dimensions, shape has only " + numDimensions () + "." );
270
+ throw new ArrayIndexOutOfBoundsException (
271
+ "Cannot take last " + n + " dimensions, shape has only " + numDimensions () + "." );
272
272
}
273
273
long [] newDimensions = new long [n ];
274
274
System .arraycopy (dimensionSizes , numDimensions () - n , newDimensions , 0 , n );
@@ -280,8 +280,8 @@ public Shape takeLast(int n) {
280
280
* {@link Shape#isUnknown()} must be {@code false}.
281
281
*
282
282
* @param firstDimension the dimension to prepend
283
- * @return a new shape with the given dimension first, followed by this Shape's dimensions,
284
- * never null
283
+ * @return a new shape with the given dimension first, followed by this Shape's dimensions, never
284
+ * null
285
285
*/
286
286
public Shape prepend (long firstDimension ) {
287
287
long [] newDimensions = new long [dimensionSizes .length + 1 ];
@@ -292,8 +292,8 @@ public Shape prepend(long firstDimension) {
292
292
}
293
293
294
294
/**
295
- * Returns a new Shape, with a new last dimension added. In order for this call to succeed,
296
- * {@link Shape#isUnknown()} must be {@code false}.
295
+ * Returns a new Shape, with a new last dimension added. In order for this call to succeed, {@link
296
+ * Shape#isUnknown()} must be {@code false}.
297
297
*
298
298
* @param lastDimension the dimension to append
299
299
* @return a new Shape with this Shape's dimensions followed by the given dimension, never null
@@ -307,38 +307,36 @@ public Shape append(long lastDimension) {
307
307
}
308
308
309
309
/**
310
- * Returns a new Shape, with another Shape's dimensions prepended.
311
- * For both this Shape and the other Shape, {@link Shape#isUnknown()} must return false.
312
- * E.g. {@code Shape.of(3,4).prepend(Shape.of(1,2)) => Shape.of(1,2,3,4) }
310
+ * Returns a new Shape, with another Shape's dimensions prepended. For both this Shape and the
311
+ * other Shape, {@link Shape#isUnknown()} must return false. E.g. {@code
312
+ * Shape.of(3,4).prepend(Shape.of(1,2)) => Shape.of(1,2,3,4) }
313
313
*
314
314
* @param other another Shape, must not be {@code null}, must not be unknown
315
- * @return A new Shape consisting of the given Shapes 's dimensions followed by this Shape's
316
- * dimensions, never null
315
+ * @return A new Shape consisting of the given Shape 's dimensions followed by this Shape's
316
+ * dimensions, never null
317
317
*/
318
318
public Shape prepend (Shape other ) {
319
319
long [] newDimensions = new long [other .dimensionSizes .length + dimensionSizes .length ];
320
- System .arraycopy (other .dimensionSizes , 0 ,
321
- newDimensions , 0 , other .dimensionSizes .length );
322
- System .arraycopy (dimensionSizes , 0 ,
323
- newDimensions , other .dimensionSizes .length , dimensionSizes .length );
320
+ System .arraycopy (other .dimensionSizes , 0 , newDimensions , 0 , other .dimensionSizes .length );
321
+ System .arraycopy (
322
+ dimensionSizes , 0 , newDimensions , other .dimensionSizes .length , dimensionSizes .length );
324
323
return Shape .of (newDimensions );
325
324
}
326
325
327
326
/**
328
- * Returns a new Shape, with another Shapes' dimensions appended.
329
- * For both this Shape and the other Shape, {@link Shape#isUnknown()} must return false.
330
- * e.g. {@code Shape.of(3,4).append(Shape.of(1,2)) => Shape.of(3,4,1,2) }
327
+ * Returns a new Shape, with another Shapes' dimensions appended. For both this Shape and the
328
+ * other Shape, {@link Shape#isUnknown()} must return false. E.g. @code
329
+ * Shape.of(3,4).append(Shape.of(1,2)) => Shape.of(3,4,1,2) }
331
330
*
332
331
* @param other another Shape, must not be {@code null}, must not be unknown
333
- * @return A new Shape consisting of this Shapes 's dimensions followed by the given Shape's
334
- * dimensions
332
+ * @return A new Shape consisting of this Shape 's dimensions followed by the given Shape's
333
+ * dimensions
335
334
*/
336
335
public Shape append (Shape other ) {
337
336
long [] newDimensions = new long [dimensionSizes .length + other .dimensionSizes .length ];
338
- System .arraycopy (dimensionSizes , 0 ,
339
- newDimensions , 0 , dimensionSizes .length );
340
- System .arraycopy (other .dimensionSizes , 0 ,
341
- newDimensions , dimensionSizes .length , other .dimensionSizes .length );
337
+ System .arraycopy (dimensionSizes , 0 , newDimensions , 0 , dimensionSizes .length );
338
+ System .arraycopy (
339
+ other .dimensionSizes , 0 , newDimensions , dimensionSizes .length , other .dimensionSizes .length );
342
340
return Shape .of (newDimensions );
343
341
}
344
342
@@ -355,4 +353,74 @@ private static long computeSize(long[] dimensionSizes) {
355
353
}
356
354
return computedSize ;
357
355
}
356
+
357
+ /**
358
+ * Determines whether another shape is compatible with this one.
359
+ *
360
+ * <p>
361
+ *
362
+ * <p>Two possibly-partially-defined shapes are compatible if there exists a fully-defined shape
363
+ * that both shapes can represent. Thus, compatibility allows the shape inference code to reason
364
+ * about partially-defined shapes. For example:
365
+ *
366
+ * <ul>
367
+ * <li><code>Shape.unknown()</code> is compatible with all shapes.
368
+ * <li><code>Shape(UNKNOWN_SIZE, UNKNOWN_SIZE)</code> is compatible with all two-dimensional
369
+ * shapes, such as <code>Shape(32, 784)</code>, and also <code>Shape.unknown()</code>. It is
370
+ * not compatible with, for example, <code>Shape(UNKNOWN_SIZE)</code> or <code>
371
+ * Shape(UNKNOWN_SIZE, UNKNOWN_SIZE, UNKNOWN_SIZE)</code>.
372
+ * <li><code>Shape(32, UNKNOWN_SIZE)</code> is compatible with all two-dimensional shapes with
373
+ * size 32 in the 0th dimension, and also <code>Shape(UNKNOWN_SIZE, UNKNOWN_SIZE)</code> and
374
+ * <code>Shape.unknown()</code>. It is not compatible with, for example, <code>Shape(32)
375
+ * </code>, <code>Shape(32, UNKNOWN_SIZE, 1)</code> or <code>Shape(64, UNKNOWN_SIZE)</code>.
376
+ * <li><code>Shape(32, 784)</code> is compatible with itself, and also <code>
377
+ * Shape(32, UNKNOWN_SIZE)</code>, <code>Shape(UNKNOWN_SIZE, 784)</code>, <code>
378
+ * Shape(UNKNOWN_SIZE, UNKNOWN_SIZE)</code> and <code>Shape.unknown()</code>. It is not
379
+ * compatible with, for example, <code>Shape(32, 1, 784)</code> or <code>Shape(UNKNOWN_SIZE)
380
+ * </code>.
381
+ * </ul>
382
+ *
383
+ * <p>The compatibility relation is reflexive and symmetric, but not transitive. For example,
384
+ * <code>Shape(32, 784)</code> is compatible with <code>Shape.unknown()</code>, and <code>
385
+ * Shape.unknown()</code> is compatible with <code>Shape(4, 4)</code>, but <code>Shape(32, 784)
386
+ * </code> is not compatible with <code>Shape(4, 4)</code>.
387
+ *
388
+ * <p>Compatibility is not the same as broadcasting. Compatible shapes must have the same number
389
+ * of dimensions and for each dimension pair, one dimension has to equal the other dimensions or
390
+ * at least one of the dimensions in the pair has to be UNKNOWN_SIZE.
391
+ *
392
+ * <p>Broadcasting allows different dimensions, but paired dimensions have to either be equal, or
393
+ * one dimension must be 1. If one shape has less dimensions than another shape, the smaller shape
394
+ * is "stretched" with dimensions of 1.
395
+ *
396
+ * @param shape The other shape
397
+ * @return true, if the two shapes are compatible.
398
+ */
399
+ public boolean isCompatibleWith (Shape shape ) {
400
+ if (!this .isUnknown () && !shape .isUnknown ()) {
401
+ if (numDimensions () != shape .numDimensions ()) {
402
+ return false ;
403
+ }
404
+ for (int i = 0 ; i < numDimensions (); i ++) {
405
+ if (!isCompatible (size (i ), shape .size (i ))) {
406
+ return false ;
407
+ }
408
+ }
409
+ }
410
+ return true ;
411
+ }
412
+
413
+ /**
414
+ * Test to see if two shape dimensions are compatible.
415
+ *
416
+ * <p>The dimensions are compatible if either dimension is <code>Shape.UNKNOWN_SIZE</code> or both
417
+ * dimensions are equal
418
+ *
419
+ * @param dim the first dimension
420
+ * @param otherDim the second dimension
421
+ * @return true, if both dimensions are compatible
422
+ */
423
+ public static boolean isCompatible (long dim , long otherDim ) {
424
+ return dim == Shape .UNKNOWN_SIZE || otherDim == Shape .UNKNOWN_SIZE || dim == otherDim ;
425
+ }
358
426
}
0 commit comments