Skip to content

Commit a85bcfb

Browse files
authored
Merge pull request tensorflow#116 from JimClarke5/Initializers1
Add initializers
2 parents fa6e6e1 + f0934ea commit a85bcfb

32 files changed

+3800
-63
lines changed

ndarray/src/main/java/org/tensorflow/ndarray/Shape.java

Lines changed: 123 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -52,10 +52,10 @@ public static Shape scalar() {
5252
/**
5353
* Create a Shape representing a scalar or an N-dimensional value.
5454
*
55-
* <p>Creates a Shape representing a scalar or an N-dimensional value (N being at least 1),
56-
* with the provided size for each dimension. A -1 indicates that the size of the corresponding
57-
* dimension is unknown. If no sizes are provided, a Shape representing a scalar is created.
58-
* For example:
55+
* <p>Creates a Shape representing a scalar or an N-dimensional value (N being at least 1), with
56+
* the provided size for each dimension. A -1 indicates that the size of the corresponding
57+
* dimension is unknown. If no sizes are provided, a Shape representing a scalar is created. For
58+
* example:
5959
*
6060
* <pre>{@code
6161
* // A 2-element vector.
@@ -88,11 +88,11 @@ public static Shape of(long... dimensionSizes) {
8888
/**
8989
* Returns the total number of elements a Tensor with this Shape would have.
9090
*
91-
* <p>If {@link Shape#isUnknown()} is true or {@link Shape#hasUnknownDimension()} is true,
92-
* {@link Shape#UNKNOWN_SIZE} is returned.
91+
* <p>If {@link Shape#isUnknown()} is true or {@link Shape#hasUnknownDimension()} is true, {@link
92+
* Shape#UNKNOWN_SIZE} is returned.
9393
*
9494
* @return The total number of elements a Tensor with this shape would have if it can be
95-
* calculated, else {@link Shape#UNKNOWN_SIZE}.
95+
* calculated, else {@link Shape#UNKNOWN_SIZE}.
9696
*/
9797
public long size() {
9898
if (size == null) {
@@ -108,12 +108,11 @@ public long size() {
108108
* an unknown size, {@link Shape#UNKNOWN_SIZE} is returned.
109109
*
110110
* @param i the index of the dimension to get the size for. If this Shape has a known number of
111-
* dimensions, it must be &lt; {@link Shape#numDimensions()}. The index may be negative,
112-
* in which case the position is counted from the end of the shape. E.g.:
113-
* {@code size(-1)} returns the size of the last dimension, {@code size(-2)} the size of
114-
* the second to last dimension etc.
111+
* dimensions, it must be < {@link Shape#numDimensions()}. The index may be negative, in which
112+
* case the position is counted from the end of the shape. E.g.: {@code size(-1)} returns the
113+
* size of the last dimension, {@code size(-2)} the size of the second to last dimension etc.
115114
* @return The size of the dimension with the given index if known, {@link Shape#UNKNOWN_SIZE}
116-
* otherwise.
115+
* otherwise.
117116
*/
118117
public long size(int i) {
119118
if (dimensionSizes == null) {
@@ -167,8 +166,8 @@ public boolean isUnknown() {
167166
}
168167

169168
/**
170-
* Returns a defensive copy of the this Shape's axes. Changes to the returned array to not
171-
* change this Shape's state. Returns null if {@link Shape#isUnknown()} is true.
169+
* Returns a defensive copy of the this Shape's axes. Changes to the returned array to not change
170+
* this Shape's state. Returns null if {@link Shape#isUnknown()} is true.
172171
*/
173172
public long[] asArray() {
174173
if (this.dimensionSizes == null) {
@@ -186,15 +185,16 @@ public int hashCode() {
186185
/**
187186
* Equals implementation for Shapes. Two Shapes are considered equal iff:
188187
*
188+
* <p>
189189
* <ul>
190-
* <li>the number of dimensions is defined and equal for both
191-
* <li>the size of each dimension is defined and equal for both
190+
* <li>the number of dimensions is defined and equal for both
191+
* <li>the size of each dimension is defined and equal for both
192192
* </ul>
193193
*
194194
* <p>If either Shape has unknown dimensions (even if they are the same in both) or if either
195-
* shape has an unknown number of dimensions (even if both return {@code true} for
196-
* {@link Shape#isUnknown()}), they are not considered equal! However, a shape will always
197-
* equal itself, even if it is unknown or contains unknown dimensions.
195+
* shape has an unknown number of dimensions (even if both return {@code true} for {@link
196+
* Shape#isUnknown()}), they are not considered equal! However, a shape will always equal itself,
197+
* even if it is unknown or contains unknown dimensions.
198198
*/
199199
@Override
200200
public boolean equals(Object obj) {
@@ -233,17 +233,17 @@ public Shape head() {
233233
}
234234

235235
/**
236-
* Returns an n-dimensional Shape with the dimensions matching the first n dimensions
237-
* of this shape
236+
* Returns an n-dimensional Shape with the dimensions matching the first n dimensions of this
237+
* shape
238238
*
239-
* @param n the number of leading dimensions to get, must be &lt;= than {@link Shape#numDimensions()}
240-
* @return an n-dimensional Shape with the first n dimensions matching the first n dimensions
241-
* of this Shape
239+
* @param n the number of leading dimensions to get, must be <= than {@link Shape#numDimensions()}
240+
* @return an n-dimensional Shape with the first n dimensions matching the first n dimensions of
241+
* this Shape
242242
*/
243243
public Shape take(int n) {
244244
if (n > numDimensions()) {
245-
throw new ArrayIndexOutOfBoundsException("Cannot take " + n +
246-
" dimensions, shape has only " + numDimensions() + ".");
245+
throw new ArrayIndexOutOfBoundsException(
246+
"Cannot take " + n + " dimensions, shape has only " + numDimensions() + ".");
247247
}
248248
long[] newDimensions = new long[n];
249249
System.arraycopy(dimensionSizes, 0, newDimensions, 0, n);
@@ -257,18 +257,18 @@ public Shape tail() {
257257
}
258258

259259
/**
260-
* Returns an n-dimensional Shape with the dimensions matching the last n dimensions
261-
* of this Shape.
260+
* Returns an n-dimensional Shape with the dimensions matching the last n dimensions of this
261+
* Shape.
262262
*
263-
* @param n the number of trailing dimensions to get, must be &lt;= than
264-
* {@link Shape#numDimensions()}
263+
* @param n the number of trailing dimensions to get, must be <= than {@link
264+
* Shape#numDimensions()}
265265
* @return an n-dimensional shape with the dimensions matching the last n dimensions of this
266-
* Shape, never null
266+
* Shape, never null
267267
*/
268268
public Shape takeLast(int n) {
269269
if (n > numDimensions()) {
270-
throw new ArrayIndexOutOfBoundsException("Cannot take last " + n +
271-
" dimensions, shape has only " + numDimensions() + ".");
270+
throw new ArrayIndexOutOfBoundsException(
271+
"Cannot take last " + n + " dimensions, shape has only " + numDimensions() + ".");
272272
}
273273
long[] newDimensions = new long[n];
274274
System.arraycopy(dimensionSizes, numDimensions() - n, newDimensions, 0, n);
@@ -280,8 +280,8 @@ public Shape takeLast(int n) {
280280
* {@link Shape#isUnknown()} must be {@code false}.
281281
*
282282
* @param firstDimension the dimension to prepend
283-
* @return a new shape with the given dimension first, followed by this Shape's dimensions,
284-
* never null
283+
* @return a new shape with the given dimension first, followed by this Shape's dimensions, never
284+
* null
285285
*/
286286
public Shape prepend(long firstDimension) {
287287
long[] newDimensions = new long[dimensionSizes.length + 1];
@@ -292,8 +292,8 @@ public Shape prepend(long firstDimension) {
292292
}
293293

294294
/**
295-
* Returns a new Shape, with a new last dimension added. In order for this call to succeed,
296-
* {@link Shape#isUnknown()} must be {@code false}.
295+
* Returns a new Shape, with a new last dimension added. In order for this call to succeed, {@link
296+
* Shape#isUnknown()} must be {@code false}.
297297
*
298298
* @param lastDimension the dimension to append
299299
* @return a new Shape with this Shape's dimensions followed by the given dimension, never null
@@ -307,38 +307,36 @@ public Shape append(long lastDimension) {
307307
}
308308

309309
/**
310-
* Returns a new Shape, with another Shape's dimensions prepended.
311-
* For both this Shape and the other Shape, {@link Shape#isUnknown()} must return false.
312-
* E.g. {@code Shape.of(3,4).prepend(Shape.of(1,2)) => Shape.of(1,2,3,4) }
310+
* Returns a new Shape, with another Shape's dimensions prepended. For both this Shape and the
311+
* other Shape, {@link Shape#isUnknown()} must return false. E.g. {@code
312+
* Shape.of(3,4).prepend(Shape.of(1,2)) => Shape.of(1,2,3,4) }
313313
*
314314
* @param other another Shape, must not be {@code null}, must not be unknown
315-
* @return A new Shape consisting of the given Shapes's dimensions followed by this Shape's
316-
* dimensions, never null
315+
* @return A new Shape consisting of the given Shape's dimensions followed by this Shape's
316+
* dimensions, never null
317317
*/
318318
public Shape prepend(Shape other) {
319319
long[] newDimensions = new long[other.dimensionSizes.length + dimensionSizes.length];
320-
System.arraycopy(other.dimensionSizes, 0,
321-
newDimensions, 0, other.dimensionSizes.length);
322-
System.arraycopy(dimensionSizes, 0,
323-
newDimensions, other.dimensionSizes.length, dimensionSizes.length);
320+
System.arraycopy(other.dimensionSizes, 0, newDimensions, 0, other.dimensionSizes.length);
321+
System.arraycopy(
322+
dimensionSizes, 0, newDimensions, other.dimensionSizes.length, dimensionSizes.length);
324323
return Shape.of(newDimensions);
325324
}
326325

327326
/**
328-
* Returns a new Shape, with another Shapes' dimensions appended.
329-
* For both this Shape and the other Shape, {@link Shape#isUnknown()} must return false.
330-
* e.g. {@code Shape.of(3,4).append(Shape.of(1,2)) => Shape.of(3,4,1,2) }
327+
* Returns a new Shape, with another Shapes' dimensions appended. For both this Shape and the
328+
* other Shape, {@link Shape#isUnknown()} must return false. E.g. @code
329+
* Shape.of(3,4).append(Shape.of(1,2)) => Shape.of(3,4,1,2) }
331330
*
332331
* @param other another Shape, must not be {@code null}, must not be unknown
333-
* @return A new Shape consisting of this Shapes's dimensions followed by the given Shape's
334-
* dimensions
332+
* @return A new Shape consisting of this Shape's dimensions followed by the given Shape's
333+
* dimensions
335334
*/
336335
public Shape append(Shape other) {
337336
long[] newDimensions = new long[dimensionSizes.length + other.dimensionSizes.length];
338-
System.arraycopy(dimensionSizes, 0,
339-
newDimensions, 0, dimensionSizes.length);
340-
System.arraycopy(other.dimensionSizes, 0,
341-
newDimensions, dimensionSizes.length, other.dimensionSizes.length);
337+
System.arraycopy(dimensionSizes, 0, newDimensions, 0, dimensionSizes.length);
338+
System.arraycopy(
339+
other.dimensionSizes, 0, newDimensions, dimensionSizes.length, other.dimensionSizes.length);
342340
return Shape.of(newDimensions);
343341
}
344342

@@ -355,4 +353,74 @@ private static long computeSize(long[] dimensionSizes) {
355353
}
356354
return computedSize;
357355
}
356+
357+
/**
358+
* Determines whether another shape is compatible with this one.
359+
*
360+
* <p>
361+
*
362+
* <p>Two possibly-partially-defined shapes are compatible if there exists a fully-defined shape
363+
* that both shapes can represent. Thus, compatibility allows the shape inference code to reason
364+
* about partially-defined shapes. For example:
365+
*
366+
* <ul>
367+
* <li><code>Shape.unknown()</code> is compatible with all shapes.
368+
* <li><code>Shape(UNKNOWN_SIZE, UNKNOWN_SIZE)</code> is compatible with all two-dimensional
369+
* shapes, such as <code>Shape(32, 784)</code>, and also <code>Shape.unknown()</code>. It is
370+
* not compatible with, for example, <code>Shape(UNKNOWN_SIZE)</code> or <code>
371+
* Shape(UNKNOWN_SIZE, UNKNOWN_SIZE, UNKNOWN_SIZE)</code>.
372+
* <li><code>Shape(32, UNKNOWN_SIZE)</code> is compatible with all two-dimensional shapes with
373+
* size 32 in the 0th dimension, and also <code>Shape(UNKNOWN_SIZE, UNKNOWN_SIZE)</code> and
374+
* <code>Shape.unknown()</code>. It is not compatible with, for example, <code>Shape(32)
375+
* </code>, <code>Shape(32, UNKNOWN_SIZE, 1)</code> or <code>Shape(64, UNKNOWN_SIZE)</code>.
376+
* <li><code>Shape(32, 784)</code> is compatible with itself, and also <code>
377+
* Shape(32, UNKNOWN_SIZE)</code>, <code>Shape(UNKNOWN_SIZE, 784)</code>, <code>
378+
* Shape(UNKNOWN_SIZE, UNKNOWN_SIZE)</code> and <code>Shape.unknown()</code>. It is not
379+
* compatible with, for example, <code>Shape(32, 1, 784)</code> or <code>Shape(UNKNOWN_SIZE)
380+
* </code>.
381+
* </ul>
382+
*
383+
* <p>The compatibility relation is reflexive and symmetric, but not transitive. For example,
384+
* <code>Shape(32, 784)</code> is compatible with <code>Shape.unknown()</code>, and <code>
385+
* Shape.unknown()</code> is compatible with <code>Shape(4, 4)</code>, but <code>Shape(32, 784)
386+
* </code> is not compatible with <code>Shape(4, 4)</code>.
387+
*
388+
* <p>Compatibility is not the same as broadcasting. Compatible shapes must have the same number
389+
* of dimensions and for each dimension pair, one dimension has to equal the other dimensions or
390+
* at least one of the dimensions in the pair has to be UNKNOWN_SIZE.
391+
*
392+
* <p>Broadcasting allows different dimensions, but paired dimensions have to either be equal, or
393+
* one dimension must be 1. If one shape has less dimensions than another shape, the smaller shape
394+
* is "stretched" with dimensions of 1.
395+
*
396+
* @param shape The other shape
397+
* @return true, if the two shapes are compatible.
398+
*/
399+
public boolean isCompatibleWith(Shape shape) {
400+
if (!this.isUnknown() && !shape.isUnknown()) {
401+
if (numDimensions() != shape.numDimensions()) {
402+
return false;
403+
}
404+
for (int i = 0; i < numDimensions(); i++) {
405+
if (!isCompatible(size(i), shape.size(i))) {
406+
return false;
407+
}
408+
}
409+
}
410+
return true;
411+
}
412+
413+
/**
414+
* Test to see if two shape dimensions are compatible.
415+
*
416+
* <p>The dimensions are compatible if either dimension is <code>Shape.UNKNOWN_SIZE</code> or both
417+
* dimensions are equal
418+
*
419+
* @param dim the first dimension
420+
* @param otherDim the second dimension
421+
* @return true, if both dimensions are compatible
422+
*/
423+
public static boolean isCompatible(long dim, long otherDim) {
424+
return dim == Shape.UNKNOWN_SIZE || otherDim == Shape.UNKNOWN_SIZE || dim == otherDim;
425+
}
358426
}

ndarray/src/test/java/org/tensorflow/ndarray/ShapeTest.java

Lines changed: 34 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,10 +16,10 @@
1616
*/
1717
package org.tensorflow.ndarray;
1818

19-
import static org.junit.jupiter.api.Assertions.*;
20-
2119
import org.junit.jupiter.api.Test;
2220

21+
import static org.junit.jupiter.api.Assertions.*;
22+
2323
public class ShapeTest {
2424

2525
@Test
@@ -135,4 +135,36 @@ public void testShapeModification() {
135135
internalShape[0] = 42L;
136136
assertEquals(2L, one.size(0));
137137
}
138+
139+
@Test
140+
public void testShapeCompatible() {
141+
Shape a = Shape.unknown();
142+
Shape b = Shape.of(2, 2);
143+
assertTrue(a.isCompatibleWith(b));
144+
assertTrue(b.isCompatibleWith(a));
145+
146+
a = Shape.of(2, 2);
147+
assertTrue(a.isCompatibleWith(b));
148+
assertTrue(b.isCompatibleWith(a));
149+
150+
a = Shape.of(2, -1);
151+
assertTrue(a.isCompatibleWith(b));
152+
assertTrue(b.isCompatibleWith(a));
153+
154+
a = Shape.of(-1, 2);
155+
assertTrue(a.isCompatibleWith(b));
156+
assertTrue(b.isCompatibleWith(a));
157+
158+
a = Shape.of(-1, -1);
159+
assertTrue(a.isCompatibleWith(b));
160+
assertTrue(b.isCompatibleWith(a));
161+
162+
a = Shape.of(1, 2);
163+
assertFalse(a.isCompatibleWith(b));
164+
assertFalse(b.isCompatibleWith(a));
165+
166+
a = Shape.of(1, 2, 3);
167+
assertFalse(a.isCompatibleWith(b));
168+
assertFalse(b.isCompatibleWith(a));
169+
}
138170
}
Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License.
14+
=======================================================================*/
15+
package org.tensorflow.framework.initializers;
16+
17+
import org.tensorflow.op.Ops;
18+
import org.tensorflow.types.family.TType;
19+
20+
/** Abstract base class for all Initializers */
21+
public abstract class BaseInitializer<T extends TType> implements Initializer<T> {
22+
23+
protected final Ops tf;
24+
25+
/**
26+
* Creates an Initializer
27+
*
28+
* @param tf the TensorFlow Ops
29+
*/
30+
protected BaseInitializer(Ops tf) {
31+
this.tf = tf;
32+
}
33+
34+
/**
35+
* Gets the TensorFlow Ops
36+
*
37+
* @return the TensorFlow Ops
38+
*/
39+
public Ops getTF() {
40+
return tf;
41+
}
42+
}

0 commit comments

Comments
 (0)