7
7
import org .tensorflow .framework .initializers .Glorot ;
8
8
import org .tensorflow .framework .initializers .VarianceScaling ;
9
9
import org .tensorflow .framework .utils .TestSession ;
10
+ import org .tensorflow .ndarray .FloatNdArray ;
10
11
import org .tensorflow .ndarray .Shape ;
11
12
import org .tensorflow .ndarray .buffer .DataBuffers ;
12
13
import org .tensorflow .op .Op ;
25
26
import org .tensorflow .types .family .TType ;
26
27
27
28
import java .util .ArrayList ;
29
+ import java .util .Arrays ;
28
30
import java .util .List ;
29
31
32
+ import static org .junit .jupiter .api .Assertions .assertArrayEquals ;
30
33
import static org .junit .jupiter .api .Assertions .assertEquals ;
31
34
32
35
/** Test cases for GradientDescent Optimizer */
@@ -125,6 +128,7 @@ public void testDeterminism() {
125
128
GraphDef def ;
126
129
String initName ;
127
130
String trainName ;
131
+ String lossName ;
128
132
129
133
String fcWeightName , fcBiasName , outputWeightName , outputBiasName ;
130
134
@@ -159,8 +163,9 @@ public void testDeterminism() {
159
163
Mean <TFloat32 > loss =
160
164
tf .math .mean (
161
165
tf .nn .raw .softmaxCrossEntropyWithLogits (output , placeholder ).loss (), tf .constant (0 ));
166
+ lossName = loss .op ().name ();
162
167
163
- GradientDescent gd = new GradientDescent (g , 0.1f );
168
+ GradientDescent gd = new GradientDescent (g , 10.0f );
164
169
Op trainingOp = gd .minimize (loss );
165
170
trainName = trainingOp .op ().name ();
166
171
@@ -177,12 +182,14 @@ public void testDeterminism() {
177
182
-14.0f , -15.0f , 0.16f , 0.17f , 0.18f , 1.9f , 0.2f
178
183
};
179
184
TFloat32 dataTensor = TFloat32 .tensorOf (Shape .of (1 , 20 ), DataBuffers .of (data ));
180
- float [] target = new float [] {0.0f , 1.0f };
185
+ float [] target = new float [] {0.2f , 0.8f };
181
186
TFloat32 targetTensor = TFloat32 .tensorOf (Shape .of (1 , 2 ), DataBuffers .of (target ));
182
187
183
- int numRuns = 10 ;
188
+ int numRuns = 20 ;
184
189
List <List <Tensor >> initialized = new ArrayList <>(numRuns );
185
190
List <List <Tensor >> trained = new ArrayList <>(numRuns );
191
+ float [] initialLoss = new float [numRuns ];
192
+ float [] postTrainingLoss = new float [numRuns ];
186
193
187
194
for (int i = 0 ; i < numRuns ; i ++) {
188
195
try (Graph g = new Graph ();
@@ -197,12 +204,16 @@ public void testDeterminism() {
197
204
.fetch (outputWeightName )
198
205
.fetch (outputBiasName )
199
206
.run ());
207
+ System .out .println ("Initialized - " + ndArrToString ((TFloat32 )initialized .get (i ).get (3 )));
200
208
201
- s .runner ()
209
+ TFloat32 lossVal = ( TFloat32 ) s .runner ()
202
210
.addTarget (trainName )
203
211
.feed ("input" , dataTensor )
204
212
.feed ("output" , targetTensor )
205
- .run ();
213
+ .fetch (lossName )
214
+ .run ().get (0 );
215
+ initialLoss [i ] = lossVal .getFloat ();
216
+ lossVal .close ();
206
217
207
218
trained .add (
208
219
s .runner ()
@@ -211,10 +222,25 @@ public void testDeterminism() {
211
222
.fetch (outputWeightName )
212
223
.fetch (outputBiasName )
213
224
.run ());
225
+ System .out .println ("Initialized - " + ndArrToString ((TFloat32 )initialized .get (i ).get (3 )));
226
+ System .out .println ("Trained - " + ndArrToString ((TFloat32 )trained .get (i ).get (3 )));
227
+
228
+ lossVal = (TFloat32 ) s .runner ()
229
+ .addTarget (trainName )
230
+ .feed ("input" , dataTensor )
231
+ .feed ("output" , targetTensor )
232
+ .fetch (lossName )
233
+ .run ().get (0 );
234
+ postTrainingLoss [i ] = lossVal .getFloat ();
235
+ lossVal .close ();
214
236
}
215
237
}
216
238
217
239
for (int i = 1 ; i < numRuns ; i ++) {
240
+ assertEquals (initialLoss [0 ],initialLoss [i ]);
241
+ assertEquals (postTrainingLoss [0 ],postTrainingLoss [i ]);
242
+ // Because the weights are references not copies.
243
+ assertEquals (initialized .get (i ),trained .get (i ));
218
244
assertEquals (
219
245
initialized .get (0 ),
220
246
initialized .get (i ),
@@ -234,4 +260,10 @@ public void testDeterminism() {
234
260
}
235
261
}
236
262
}
263
+
264
+ private static String ndArrToString (FloatNdArray ndarray ) {
265
+ StringBuffer sb = new StringBuffer ();
266
+ ndarray .scalars ().forEachIndexed ((idx ,array ) -> sb .append (Arrays .toString (idx )).append (" = " ).append (array .getFloat ()).append ("\n " ));
267
+ return sb .toString ();
268
+ }
237
269
}
0 commit comments