@@ -42,7 +42,10 @@ public static void main(String[] args) throws Exception {
42
42
Tensors .create (Paths .get (checkpointDir , "ckpt" ).toString ())) {
43
43
graph .importGraphDef (graphDef );
44
44
45
- // Initialize or restore
45
+ // Initialize or restore.
46
+ // The names of the tensors in the graph are printed out by the program
47
+ // that created the graph:
48
+ // https://github.com/tensorflow/models/blob/master/samples/languages/java/training/model/create_graph.py
46
49
if (checkpointExists ) {
47
50
sess .runner ().feed ("save/Const" , checkpointPrefix ).addTarget ("save/restore_all" ).run ();
48
51
} else {
@@ -60,14 +63,18 @@ public static void main(String[] args) throws Exception {
60
63
float in = r .nextFloat ();
61
64
try (Tensor <Float > input = Tensors .create (in );
62
65
Tensor <Float > target = Tensors .create (3 * in + 2 )) {
66
+ // Again the tensor names are from the program that created the graph.
67
+ // https://github.com/tensorflow/models/blob/master/samples/languages/java/training/model/create_graph.py
63
68
sess .runner ().feed ("input" , input ).feed ("target" , target ).addTarget ("train" ).run ();
64
69
}
65
70
}
66
71
System .out .printf ("After %5d examples: " , i *NUM_EXAMPLES );
67
72
printVariables (sess );
68
73
}
69
74
70
- // Checkpoint
75
+ // Checkpoint.
76
+ // The feed and target name are from the program that created the graph.
77
+ // https://github.com/tensorflow/models/blob/master/samples/languages/java/training/model/create_graph.py.
71
78
sess .runner ().feed ("save/Const" , checkpointPrefix ).addTarget ("save/control_dependency" ).run ();
72
79
73
80
// Example of "inference" in the same graph:
0 commit comments