
Commit 5dd4b30

Setting all the optimizers to have useLocking = True, like Keras. Adding a determinism test that's currently failing.
1 parent 63458d2 commit 5dd4b30
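For context, every raw `tf.train.apply*` op touched by this commit accepts per-call options, and the change passes `useLocking(true)` to each so the kernel holds the variable's lock while applying the update, matching Keras's behavior. Below is a minimal sketch of the pattern using the same `applyGradientDescent` API changed in this commit; the class name, variable contents, and learning rate are illustrative only, not part of the commit:

import org.tensorflow.Graph;
import org.tensorflow.op.Op;
import org.tensorflow.op.Ops;
import org.tensorflow.op.core.Constant;
import org.tensorflow.op.core.Variable;
import org.tensorflow.op.train.ApplyGradientDescent;
import org.tensorflow.types.TFloat32;

public class UseLockingSketch {
  public static void main(String[] args) {
    try (Graph g = new Graph()) {
      Ops tf = Ops.create(g);
      // A variable and a stand-in "gradient" (a real one would come from backprop).
      Variable<TFloat32> var = tf.variable(tf.constant(new float[] {1f, 2f, 3f}));
      Constant<TFloat32> grad = tf.constant(new float[] {0.1f, 0.2f, 0.3f});
      // useLocking(true) makes the kernel take the variable's lock for the update,
      // so concurrent applies are serialized instead of interleaving reads/writes.
      Op step =
          tf.train.applyGradientDescent(
              var, tf.constant(0.01f), grad, ApplyGradientDescent.useLocking(true));
    }
  }
}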

9 files changed: +166 -9 lines changed

tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/AdaDelta.java

Lines changed: 3 additions & 1 deletion
@@ -20,6 +20,7 @@
 import org.tensorflow.Output;
 import org.tensorflow.op.Op;
 import org.tensorflow.op.core.Variable;
+import org.tensorflow.op.train.ApplyAdadelta;
 import org.tensorflow.types.family.TType;

 import java.util.List;
@@ -160,7 +161,8 @@ protected <T extends TType> Op applyDense(Output<T> gradient, Output<T> variable
         tf.dtypes.cast(tf.constant(learningRate), gradient.type()),
         tf.dtypes.cast(tf.constant(rho), gradient.type()),
         tf.dtypes.cast(tf.constant(epsilon), gradient.type()),
-        gradient);
+        gradient,
+        ApplyAdadelta.useLocking(true));
   }

   /** {@inheritDoc} */

tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/AdaGrad.java

Lines changed: 5 additions & 1 deletion
@@ -19,6 +19,7 @@
 import org.tensorflow.Operand;
 import org.tensorflow.Output;
 import org.tensorflow.op.Op;
+import org.tensorflow.op.train.ApplyAdagrad;
 import org.tensorflow.op.core.Variable;
 import org.tensorflow.types.family.TType;

@@ -42,6 +43,9 @@ public class AdaGrad extends Optimizer {
   public static final float LEARNING_RATE_DEFAULT = 0.001f;
   public static final float INITIAL_ACCUMULATOR_DEFAULT = 0.01f;

+  private static final ApplyAdagrad.Options[] opts = new ApplyAdagrad.Options[] {
+      ApplyAdagrad.updateSlots(true), ApplyAdagrad.useLocking(true)};
+
   private final float learningRate;

   private final float initialAccumulatorValue;
@@ -140,7 +144,7 @@ private <T extends TType> void createAdaGradSlot(Output<T> v) {
   protected <T extends TType> Op applyDense(Output<T> gradient, Output<T> variable) {
     Variable<T> slot = getSlot(variable, ACCUMULATOR).get();
     return tf.train.applyAdagrad(
-        variable, slot, tf.dtypes.cast(tf.constant(learningRate), gradient.type()), gradient);
+        variable, slot, tf.dtypes.cast(tf.constant(learningRate), gradient.type()), gradient, opts);
   }

   /** {@inheritDoc} */
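AdaGrad caches its two options once in a static array rather than rebuilding them on every call; since `applyAdagrad` takes options as varargs, passing the shared array is equivalent to spelling them out inline. A sketch of that equivalent inline form of the same `applyDense` body (not the committed code):

protected <T extends TType> Op applyDense(Output<T> gradient, Output<T> variable) {
  Variable<T> slot = getSlot(variable, ACCUMULATOR).get();
  return tf.train.applyAdagrad(
      variable,
      slot,
      tf.dtypes.cast(tf.constant(learningRate), gradient.type()),
      gradient,
      ApplyAdagrad.updateSlots(true), // accumulator updates stay enabled (the op's default)
      ApplyAdagrad.useLocking(true)); // take the variable's lock during the update
}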

tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/AdaGradDA.java

Lines changed: 3 additions & 1 deletion
@@ -22,6 +22,7 @@
 import org.tensorflow.op.Op;
 import org.tensorflow.op.core.Assign;
 import org.tensorflow.op.core.Variable;
+import org.tensorflow.op.train.ApplyAdagradDa;
 import org.tensorflow.types.TInt64;
 import org.tensorflow.types.family.TType;

@@ -219,7 +220,8 @@ protected <T extends TType> Op applyDense(Output<T> gradient, Output<T> variable
         tf.dtypes.cast(tf.constant(learningRate), gradient.type()),
         tf.dtypes.cast(tf.constant(l1Strength), gradient.type()),
         tf.dtypes.cast(tf.constant(l2Strength), gradient.type()),
-        globalStep);
+        globalStep,
+        ApplyAdagradDa.useLocking(true));
   }

   /**

tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/Adam.java

Lines changed: 3 additions & 1 deletion
@@ -26,6 +26,7 @@
 import org.tensorflow.op.core.Assign;
 import org.tensorflow.op.core.Constant;
 import org.tensorflow.op.core.Variable;
+import org.tensorflow.op.train.ApplyAdam;
 import org.tensorflow.types.TFloat32;
 import org.tensorflow.types.family.TType;

@@ -237,7 +238,8 @@ protected <T extends TType> Op applyDense(Output<T> gradient, Output<T> variable
         tf.dtypes.cast(betaOneConst, gradient.type()),
         tf.dtypes.cast(betaTwoConst, gradient.type()),
         tf.dtypes.cast(epsilonConst, gradient.type()),
-        gradient);
+        gradient,
+        ApplyAdam.useLocking(true));
   }

   /**

tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/Adamax.java

Lines changed: 2 additions & 1 deletion
@@ -170,7 +170,8 @@ protected <T extends TType> Op applyDense(Output<T> gradient, Output<T> variable
         tf.dtypes.cast(betaOneConst, gradient.type()),
         tf.dtypes.cast(betaTwoConst, gradient.type()),
         tf.dtypes.cast(epsilonConst, gradient.type()),
-        gradient);
+        gradient,
+        ApplyAdaMax.useLocking(true));
   }

   /** {@inheritDoc} */

tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/GradientDescent.java

Lines changed: 5 additions & 1 deletion
@@ -18,6 +18,7 @@
 import org.tensorflow.Graph;
 import org.tensorflow.Output;
 import org.tensorflow.op.Op;
+import org.tensorflow.op.train.ApplyGradientDescent;
 import org.tensorflow.types.family.TType;

 /**
@@ -66,7 +67,10 @@ public GradientDescent(Graph graph, String name, float learningRate) {
   @Override
   protected <T extends TType> Op applyDense(Output<T> gradient, Output<T> variable) {
     return tf.train.applyGradientDescent(
-        variable, tf.dtypes.cast(tf.constant(learningRate), gradient.type()), gradient);
+        variable,
+        tf.dtypes.cast(tf.constant(learningRate), gradient.type()),
+        gradient,
+        ApplyGradientDescent.useLocking(true));
   }

   /** {@inheritDoc} */

tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/Momentum.java

Lines changed: 2 additions & 1 deletion
@@ -139,7 +139,8 @@ protected <T extends TType> Op applyDense(Output<T> gradient, Output<T> variable
         tf.dtypes.cast(tf.constant(learningRate), gradient.type()),
         gradient,
        tf.dtypes.cast(tf.constant(momentum), gradient.type()),
-        ApplyMomentum.useNesterov(useNesterov));
+        ApplyMomentum.useNesterov(useNesterov),
+        ApplyMomentum.useLocking(true));
   }

   /** {@inheritDoc} */

tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/RMSProp.java

Lines changed: 6 additions & 2 deletions
@@ -20,6 +20,8 @@
 import org.tensorflow.Output;
 import org.tensorflow.op.Op;
 import org.tensorflow.op.core.Variable;
+import org.tensorflow.op.train.ApplyCenteredRmsProp;
+import org.tensorflow.op.train.ApplyRmsProp;
 import org.tensorflow.types.family.TType;

 import java.util.List;
@@ -202,7 +204,8 @@ protected <T extends TType> Op applyDense(Output<T> gradient, Output<T> variable
           tf.dtypes.cast(tf.constant(decay), gradient.type()),
           tf.dtypes.cast(tf.constant(momentum), gradient.type()),
           tf.dtypes.cast(tf.constant(epsilon), gradient.type()),
-          gradient);
+          gradient,
+          ApplyCenteredRmsProp.useLocking(true));
     }
     return tf.train.applyRmsProp(
         variable,
@@ -212,7 +215,8 @@ protected <T extends TType> Op applyDense(Output<T> gradient, Output<T> variable
         tf.dtypes.cast(tf.constant(decay), gradient.type()),
         tf.dtypes.cast(tf.constant(momentum), gradient.type()),
         tf.dtypes.cast(tf.constant(epsilon), gradient.type()),
-        gradient);
+        gradient,
+        ApplyRmsProp.useLocking(true));
   }

   /** {@inheritDoc} */

tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/GradientDescentTest.java

Lines changed: 137 additions & 0 deletions
@@ -2,13 +2,25 @@

 import org.junit.jupiter.api.*;
 import org.tensorflow.Graph;
+import org.tensorflow.Session;
+import org.tensorflow.Tensor;
+import org.tensorflow.framework.initializers.Glorot;
+import org.tensorflow.framework.initializers.VarianceScaling;
 import org.tensorflow.framework.utils.TestSession;
 import org.tensorflow.ndarray.Shape;
+import org.tensorflow.ndarray.buffer.DataBuffers;
 import org.tensorflow.op.Op;
 import org.tensorflow.op.Ops;
 import org.tensorflow.op.core.Assign;
 import org.tensorflow.op.core.Constant;
+import org.tensorflow.op.core.Init;
+import org.tensorflow.op.core.Placeholder;
 import org.tensorflow.op.core.Variable;
+import org.tensorflow.op.math.Add;
+import org.tensorflow.op.math.Mean;
+import org.tensorflow.op.nn.Relu;
+import org.tensorflow.proto.framework.ConfigProto;
+import org.tensorflow.proto.framework.GraphDef;
 import org.tensorflow.types.TFloat32;
 import org.tensorflow.types.family.TType;

@@ -97,4 +109,129 @@ public void testBasic() {
       session.evaluate(expectedVar1, var1);
     }
   }
+
+  // This test fails due to initialization and gradient issues. It should not, but it seems to be
+  // a problem in TF-core.
+  @Disabled
+  @Test
+  public void testDeterminism() {
+    ConfigProto config =
+        ConfigProto.newBuilder()
+            .setIntraOpParallelismThreads(1)
+            .setInterOpParallelismThreads(1)
+            .build();
+
+    GraphDef def;
+    String initName;
+    String trainName;
+
+    String fcWeightName, fcBiasName, outputWeightName, outputBiasName;
+
+    try (Graph g = new Graph()) {
+      Ops tf = Ops.create(g);
+
+      Glorot<TFloat32> initializer =
+          new Glorot<>(tf, VarianceScaling.Distribution.TRUNCATED_NORMAL, 1L);
+      // Inputs
+      Placeholder<TFloat32> input =
+          tf.withName("input").placeholder(TFloat32.class, Placeholder.shape(Shape.of(-1, 20)));
+
+      // Fully connected layer
+      Variable<TFloat32> fcWeights =
+          tf.variable(initializer.call(tf.array(20L, 200L), TFloat32.class));
+      fcWeightName = fcWeights.op().name();
+      Variable<TFloat32> fcBiases = tf.variable(tf.fill(tf.array(200), tf.constant(0.1f)));
+      fcBiasName = fcBiases.op().name();
+      Relu<TFloat32> relu = tf.nn.relu(tf.math.add(tf.linalg.matMul(input, fcWeights), fcBiases));
+
+      // Output layer
+      Variable<TFloat32> outputWeights =
+          tf.variable(initializer.call(tf.array(200L, 2L), TFloat32.class));
+      outputWeightName = outputWeights.op().name();
+      Variable<TFloat32> outputBiases = tf.variable(tf.fill(tf.array(2L), tf.constant(0.1f)));
+      outputBiasName = outputBiases.op().name();
+      Add<TFloat32> output = tf.math.add(tf.linalg.matMul(relu, outputWeights), outputBiases);
+
+      // Loss
+      Placeholder<TFloat32> placeholder =
+          tf.withName("output").placeholder(TFloat32.class, Placeholder.shape(Shape.of(-1, 2)));
+      Mean<TFloat32> loss =
+          tf.math.mean(
+              tf.nn.raw.softmaxCrossEntropyWithLogits(output, placeholder).loss(), tf.constant(0));
+
+      GradientDescent gd = new GradientDescent(g, 0.1f);
+      Op trainingOp = gd.minimize(loss);
+      trainName = trainingOp.op().name();
+
+      // Create the init op
+      Init init = tf.init();
+      initName = init.op().name();
+
+      def = g.toGraphDef();
+    }
+
+    float[] data =
+        new float[] {
+          1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, -8.0f, -9.0f, 10.0f, 11.0f, 12.0f, 13.0f,
+          -14.0f, -15.0f, 0.16f, 0.17f, 0.18f, 1.9f, 0.2f
+        };
+    TFloat32 dataTensor = TFloat32.tensorOf(Shape.of(1, 20), DataBuffers.of(data));
+    float[] target = new float[] {0.0f, 1.0f};
+    TFloat32 targetTensor = TFloat32.tensorOf(Shape.of(1, 2), DataBuffers.of(target));
+
+    int numRuns = 10;
+    List<List<Tensor>> initialized = new ArrayList<>(numRuns);
+    List<List<Tensor>> trained = new ArrayList<>(numRuns);
+
+    for (int i = 0; i < numRuns; i++) {
+      try (Graph g = new Graph();
+          Session s = new Session(g, config)) {
+        g.importGraphDef(def);
+        s.run(initName);
+
+        initialized.add(
+            s.runner()
+                .fetch(fcWeightName)
+                .fetch(fcBiasName)
+                .fetch(outputWeightName)
+                .fetch(outputBiasName)
+                .run());
+
+        s.runner()
+            .addTarget(trainName)
+            .feed("input", dataTensor)
+            .feed("output", targetTensor)
+            .run();
+
+        trained.add(
+            s.runner()
+                .fetch(fcWeightName)
+                .fetch(fcBiasName)
+                .fetch(outputWeightName)
+                .fetch(outputBiasName)
+                .run());
+      }
+    }
+
+    for (int i = 1; i < numRuns; i++) {
+      assertEquals(
+          initialized.get(0),
+          initialized.get(i),
+          "Variables not initialized identically (0," + i + ")");
+      assertEquals(
+          trained.get(0), trained.get(i), "Variables not trained identically (0," + i + ")");
+    }
+
+    for (List<Tensor> curInit : initialized) {
+      for (Tensor t : curInit) {
+        t.close();
+      }
+    }
+    for (List<Tensor> curTrained : trained) {
+      for (Tensor t : curTrained) {
+        t.close();
+      }
+    }
+  }
 }
