Skip to content

Commit cdd0298

Browse files
authored
Adds Session.restore to allow loading of checkpoints saved by Session.save (#225)
* Initial support for restoring a saved session. * Updating surefire version to 3.0.0-M5 to get around a stdout test issue. * Updates to ensure the saverdef doesn't mutate the graph twice. * Matching the filename placeholder in saverdef. * Unpicking the nested placeholders, it makes various savers throw exceptions.
1 parent 858bd19 commit cdd0298

File tree

4 files changed

+84
-18
lines changed

4 files changed

+84
-18
lines changed

tensorflow-core/tensorflow-core-api/pom.xml

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -373,7 +373,7 @@
373373
</plugin>
374374
<plugin>
375375
<artifactId>maven-surefire-plugin</artifactId>
376-
<version>2.22.0</version>
376+
<version>3.0.0-M5</version>
377377
<executions>
378378
<execution>
379379
<!--
@@ -389,6 +389,8 @@
389389
</execution>
390390
</executions>
391391
<configuration>
392+
<!-- Activate the use of TCP to transmit events to the plugin -->
393+
<forkNode implementation="org.apache.maven.plugin.surefire.extensions.SurefireForkNodeFactory"/>
392394
<additionalClasspathElements>
393395
<additionalClasspathElement>${project.build.directory}/${project.artifactId}-${project.version}-${native.classifier}.jar</additionalClasspathElement>
394396
<!-- Note: the following path is not accessible in deploying profile, so other libraries like

tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Graph.java

Lines changed: 28 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828

2929
import com.google.protobuf.InvalidProtocolBufferException;
3030
import java.util.ArrayList;
31+
import java.util.Arrays;
3132
import java.util.Collections;
3233
import java.util.Iterator;
3334
import java.util.List;
@@ -47,6 +48,7 @@
4748
import org.tensorflow.op.Op;
4849
import org.tensorflow.op.Ops;
4950
import org.tensorflow.op.core.Constant;
51+
import org.tensorflow.op.core.Identity;
5052
import org.tensorflow.op.core.NoOp;
5153
import org.tensorflow.op.core.Placeholder;
5254
import org.tensorflow.op.train.Restore;
@@ -439,15 +441,32 @@ public Output<?>[] whileLoop(
439441
* Return the {@link SaverDef} instance used to save the state of all variables present in
440442
* this graph.
441443
*
442-
* <p/>On the first call of this method, all nodes necessary to save and restore the state of the
443-
* variables are added to the graph. Consequently, any variables that are added to the graph after
444-
* this call could not be saved nor restored using this {@link SaverDef}.
444+
* <p/> The first time this method is called it builds the {@link SaverDef}. If this graph already
445+
* contains a "save/restore_all" operation then it is assumed to contain all necessary saving and
446+
* restoring operations. If that operation does not exist then the graph is mutated to add all
447+
* the nodes necessary to save and restore the state of the graph. Consequently, any variables
448+
* that are added to the graph after this call will not be saved nor restored using this
449+
* {@link SaverDef}.
445450
*
446451
* @return a {@link SaverDef} instance
447452
*/
448453
synchronized SaverDef saverDef() {
449454
if (saverDef == null) {
450-
saverDef = addVariableSaver(this);
455+
// Check to see if this graph has a restore operation
456+
if (operation("save/restore_all") == null) {
457+
// No saver, create one by mutating the graph
458+
saverDef = addVariableSaver(this);
459+
} else {
460+
// This graph already has saving/restoring operations,
461+
// regenerate SaverDef without mutating. The names mirror
462+
// the python implementation for compatibility.
463+
// https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/training/saver.py
464+
saverDef = SaverDef.newBuilder()
465+
.setFilenameTensorName("save/filename")
466+
.setSaveTensorName("save/control_dependency")
467+
.setRestoreOpName("save/restore_all")
468+
.build();
469+
}
451470
}
452471
return saverDef;
453472
}
@@ -798,13 +817,15 @@ private static SaverDef addVariableSaver(Graph graph) {
798817
Constant<TString> varNamesTensor = tf.constant(StdArrays.ndCopyOf(varNames.toArray(tmp)));
799818
Operand<TString> varSlices = tf.zerosLike(varNamesTensor);
800819

801-
Placeholder<TString> saveFilename = tf.placeholder(TString.class);
820+
Placeholder<TString> saveFilename = tf.withName("filename").placeholder(TString.class);
802821
Save saveVariables = tf.train.save(
803822
saveFilename,
804823
varNamesTensor,
805824
varSlices,
806825
varOutputs
807826
);
827+
Identity<TString> id = tf.withControlDependencies(Arrays.asList(saveFilename,saveVariables))
828+
.withName("control_dependency").identity(saveFilename);
808829
Restore restoreVariables = tf.train.restore(
809830
saveFilename,
810831
varNamesTensor,
@@ -815,11 +836,11 @@ private static SaverDef addVariableSaver(Graph graph) {
815836
for (int i = 0; i < varOutputs.size(); ++i) {
816837
restoreOps.add(tf.assign(varOutputs.get(i), (Operand) restoreVariables.tensors().get(i)));
817838
}
818-
NoOp restoreAll = tf.withControlDependencies(restoreOps).noOp();
839+
NoOp restoreAll = tf.withControlDependencies(restoreOps).withName("restore_all").noOp();
819840

820841
return SaverDef.newBuilder()
821842
.setFilenameTensorName(saveFilename.op().name())
822-
.setSaveTensorName(saveVariables.op().name())
843+
.setSaveTensorName(id.op().name())
823844
.setRestoreOpName(restoreAll.op().name())
824845
.build();
825846
}

tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Session.java

Lines changed: 24 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -512,16 +512,35 @@ public void runInit(){
512512
* <i>mymodel/myvariables</i> and named <i>variables.data-*-of-*</i>
513513
*
514514
* <p>Note that this method might alter the underlying graph if it is the first time that one
515-
* of its session is saved, see {@link Graph#saverDef()} for more details.
515+
* of its sessions is saved, see {@link Graph#saverDef()} for more details.
516516
*
517517
* @param prefix prefix to the variable files to save
518518
*/
519519
public void save(String prefix) {
520520
SaverDef saverDef = graph.saverDef();
521-
runner()
522-
.addTarget(saverDef.getSaveTensorName())
523-
.feed(saverDef.getFilenameTensorName(), TString.scalarOf(prefix))
524-
.run();
521+
runner().addTarget(saverDef.getSaveTensorName())
522+
.feed(saverDef.getFilenameTensorName(), TString.scalarOf(prefix))
523+
.run();
524+
}
525+
526+
/**
527+
* Restore the actual state of the variables of this session's graph.
528+
*
529+
* <p>{@code prefix} is the path where the files containing the variables state live,
530+
* followed by the filename prefix. For example, if {@code prefix} is set to
531+
* <i>mymodel/myvariables/variables</i>, then the files are loaded from
532+
* <i>mymodel/myvariables</i> and named <i>variables.data-*-of-*</i>
533+
*
534+
* <p>Note that this method might alter the underlying graph if it is the first time that one
535+
* of its sessions is saved, see {@link Graph#saverDef()} for more details.
536+
*
537+
* @param prefix prefix to restore from
538+
*/
539+
public void restore(String prefix) {
540+
SaverDef saverDef = graph.saverDef();
541+
runner().addTarget(saverDef.getRestoreOpName())
542+
.feed(saverDef.getFilenameTensorName(), TString.scalarOf(prefix))
543+
.run();
525544
}
526545

527546
/**

tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/SessionTest.java

Lines changed: 29 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -20,10 +20,14 @@
2020
import static org.junit.jupiter.api.Assertions.assertTrue;
2121
import static org.junit.jupiter.api.Assertions.fail;
2222

23+
import java.io.BufferedOutputStream;
24+
import java.io.File;
25+
import java.io.FileOutputStream;
2326
import java.io.IOException;
2427
import java.nio.file.Files;
2528
import java.nio.file.Path;
26-
import java.nio.file.Paths;
29+
import java.util.Comparator;
30+
2731
import org.junit.jupiter.api.Test;
2832
import org.tensorflow.op.Ops;
2933
import org.tensorflow.op.core.Init;
@@ -32,6 +36,7 @@
3236
import org.tensorflow.op.linalg.MatMul;
3337
import org.tensorflow.op.math.Add;
3438
import org.tensorflow.proto.framework.ConfigProto;
39+
import org.tensorflow.proto.framework.GraphDef;
3540
import org.tensorflow.proto.framework.RunOptions;
3641
import org.tensorflow.ndarray.Shape;
3742
import org.tensorflow.ndarray.NdArrays;
@@ -208,21 +213,40 @@ public void runInitByName() {
208213
}
209214

210215
@Test
211-
public void save() throws IOException {
212-
Path testFolder = Files.createTempDirectory("tf-session-save-test");
216+
public void saveAndRestore() throws IOException {
217+
Path testFolder = Files.createTempDirectory("tf-session-save-restore-test");
213218
try (Graph g = new Graph()) {
214219
Ops tf = Ops.create(g);
215-
Variable<TFloat32> x = tf.variable(tf.random.randomUniform(tf.constant(Shape.of(3, 3L)), TFloat32.class));
216-
Variable<TFloat32> y = tf.variable(tf.random.randomUniform(tf.constant(Shape.of(3, 3L)), TFloat32.class));
220+
Variable<TFloat32> x = tf.withName("x").variable(tf.random.randomUniform(tf.constant(Shape.of(3, 3L)), TFloat32.class));
221+
Variable<TFloat32> y = tf.withName("y").variable(tf.random.randomUniform(tf.constant(Shape.of(3, 3L)), TFloat32.class));
217222
Init init = tf.init();
218223

219224
try (Session s = new Session(g)) {
220225
s.run(init);
221226
s.save(testFolder.resolve("checkpoint").toString());
227+
GraphDef graphDef = g.toGraphDef();
228+
229+
try (Graph restoredGraph = new Graph()) {
230+
restoredGraph.importGraphDef(graphDef);
231+
try (Session restoredSession = new Session(restoredGraph)) {
232+
restoredSession.restore(testFolder.resolve("checkpoint").toString());
233+
try (AutoCloseableList<Tensor> oldList = new AutoCloseableList<>(s.runner().fetch("x").fetch("y").run());
234+
AutoCloseableList<Tensor> newList = new AutoCloseableList<>(restoredSession.runner().fetch("x").fetch("y").run())){
235+
assertEquals(oldList.get(0),newList.get(0));
236+
assertEquals(oldList.get(1),newList.get(1));
237+
}
238+
}
239+
}
222240
}
223241
}
224242
assertTrue(Files.exists(testFolder.resolve("checkpoint.index")));
225243
assertTrue(Files.exists(testFolder.resolve("checkpoint.data-00000-of-00001")));
244+
245+
// Cleanup test dir
246+
Files.walk(testFolder)
247+
.sorted(Comparator.reverseOrder())
248+
.map(Path::toFile)
249+
.forEach(File::delete);
226250
}
227251

228252
private static RunOptions fullTraceRunOptions() {

0 commit comments

Comments
 (0)