Commit c0fc351

More work on the GradientDescentTest.
1 parent 5dd4b30 commit c0fc351

File tree

1 file changed: +37 −5 lines

tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/GradientDescentTest.java

Lines changed: 37 additions & 5 deletions
@@ -7,6 +7,7 @@
 import org.tensorflow.framework.initializers.Glorot;
 import org.tensorflow.framework.initializers.VarianceScaling;
 import org.tensorflow.framework.utils.TestSession;
+import org.tensorflow.ndarray.FloatNdArray;
 import org.tensorflow.ndarray.Shape;
 import org.tensorflow.ndarray.buffer.DataBuffers;
 import org.tensorflow.op.Op;
@@ -25,8 +26,10 @@
 import org.tensorflow.types.family.TType;
 
 import java.util.ArrayList;
+import java.util.Arrays;
 import java.util.List;
 
+import static org.junit.jupiter.api.Assertions.assertArrayEquals;
 import static org.junit.jupiter.api.Assertions.assertEquals;
 
 /** Test cases for GradientDescent Optimizer */
@@ -125,6 +128,7 @@ public void testDeterminism() {
     GraphDef def;
     String initName;
     String trainName;
+    String lossName;
 
     String fcWeightName, fcBiasName, outputWeightName, outputBiasName;
 
@@ -159,8 +163,9 @@ public void testDeterminism() {
       Mean<TFloat32> loss =
           tf.math.mean(
               tf.nn.raw.softmaxCrossEntropyWithLogits(output, placeholder).loss(), tf.constant(0));
+      lossName = loss.op().name();
 
-      GradientDescent gd = new GradientDescent(g, 0.1f);
+      GradientDescent gd = new GradientDescent(g, 10.0f);
       Op trainingOp = gd.minimize(loss);
       trainName = trainingOp.op().name();
 
@@ -177,12 +182,14 @@ public void testDeterminism() {
          -14.0f, -15.0f, 0.16f, 0.17f, 0.18f, 1.9f, 0.2f
        };
     TFloat32 dataTensor = TFloat32.tensorOf(Shape.of(1, 20), DataBuffers.of(data));
-    float[] target = new float[] {0.0f, 1.0f};
+    float[] target = new float[] {0.2f, 0.8f};
     TFloat32 targetTensor = TFloat32.tensorOf(Shape.of(1, 2), DataBuffers.of(target));
 
-    int numRuns = 10;
+    int numRuns = 20;
     List<List<Tensor>> initialized = new ArrayList<>(numRuns);
     List<List<Tensor>> trained = new ArrayList<>(numRuns);
+    float[] initialLoss = new float[numRuns];
+    float[] postTrainingLoss = new float[numRuns];
 
     for (int i = 0; i < numRuns; i++) {
       try (Graph g = new Graph();
@@ -197,12 +204,16 @@ public void testDeterminism() {
                 .fetch(outputWeightName)
                 .fetch(outputBiasName)
                 .run());
+        System.out.println("Initialized - " + ndArrToString((TFloat32) initialized.get(i).get(3)));
 
-        s.runner()
+        TFloat32 lossVal = (TFloat32) s.runner()
             .addTarget(trainName)
             .feed("input", dataTensor)
             .feed("output", targetTensor)
-            .run();
+            .fetch(lossName)
+            .run().get(0);
+        initialLoss[i] = lossVal.getFloat();
+        lossVal.close();
 
         trained.add(
             s.runner()
@@ -211,10 +222,25 @@ public void testDeterminism() {
                 .fetch(outputWeightName)
                 .fetch(outputBiasName)
                 .run());
+        System.out.println("Initialized - " + ndArrToString((TFloat32) initialized.get(i).get(3)));
+        System.out.println("Trained - " + ndArrToString((TFloat32) trained.get(i).get(3)));
+
+        lossVal = (TFloat32) s.runner()
+            .addTarget(trainName)
+            .feed("input", dataTensor)
+            .feed("output", targetTensor)
+            .fetch(lossName)
+            .run().get(0);
+        postTrainingLoss[i] = lossVal.getFloat();
+        lossVal.close();
       }
     }
 
     for (int i = 1; i < numRuns; i++) {
+      assertEquals(initialLoss[0], initialLoss[i]);
+      assertEquals(postTrainingLoss[0], postTrainingLoss[i]);
+      // Because the weights are references not copies.
+      assertEquals(initialized.get(i), trained.get(i));
       assertEquals(
           initialized.get(0),
           initialized.get(i),
@@ -234,4 +260,10 @@ public void testDeterminism() {
       }
     }
   }
+
+  private static String ndArrToString(FloatNdArray ndarray) {
+    StringBuffer sb = new StringBuffer();
+    ndarray.scalars().forEachIndexed((idx, array) -> sb.append(Arrays.toString(idx)).append(" = ").append(array.getFloat()).append("\n"));
+    return sb.toString();
+  }
 }
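
The loss bookkeeping added in this commit follows the same pattern in both places: run the training target once, fetch the scalar loss by the op name recorded at graph-construction time, read it, then close the tensor. A minimal sketch of that pattern as a standalone helper is shown below; the runTrainingStep name and its parameter list are illustrative only and not part of the commit, while the runner/feed/fetch/getFloat calls are the ones visible in the diff above.

  // Hypothetical helper: one GradientDescent step plus a scalar loss read-back.
  private static float runTrainingStep(
      Session s, String trainName, String lossName, TFloat32 dataTensor, TFloat32 targetTensor) {
    TFloat32 lossVal =
        (TFloat32)
            s.runner()
                .addTarget(trainName)      // apply the update produced by gd.minimize(loss)
                .feed("input", dataTensor)
                .feed("output", targetTensor)
                .fetch(lossName)           // lossName = loss.op().name(), captured at build time
                .run()
                .get(0);
    float loss = lossVal.getFloat();       // scalar loss used by the determinism assertions
    lossVal.close();
    return loss;
  }

Because every iteration re-imports the same GraphDef and feeds identical input and target tensors, initialLoss[i] and postTrainingLoss[i] should match run 0's values exactly, which is what the new assertEquals calls at the top of the comparison loop verify.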
