Skip to content

Adds Session.restore to allow loading of checkpoints saved by Session.save #225

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Feb 24, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion tensorflow-core/tensorflow-core-api/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -373,7 +373,7 @@
</plugin>
<plugin>
<artifactId>maven-surefire-plugin</artifactId>
<version>2.22.0</version>
<version>3.0.0-M5</version>
<executions>
<execution>
<!--
Expand All @@ -389,6 +389,8 @@
</execution>
</executions>
<configuration>
<!-- Activate the use of TCP to transmit events to the plugin -->
<forkNode implementation="org.apache.maven.plugin.surefire.extensions.SurefireForkNodeFactory"/>
<additionalClasspathElements>
<additionalClasspathElement>${project.build.directory}/${project.artifactId}-${project.version}-${native.classifier}.jar</additionalClasspathElement>
<!-- Note: the following path is not accessible in deploying profile, so other libraries like
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@

import com.google.protobuf.InvalidProtocolBufferException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.Iterator;
import java.util.List;
Expand All @@ -47,6 +48,7 @@
import org.tensorflow.op.Op;
import org.tensorflow.op.Ops;
import org.tensorflow.op.core.Constant;
import org.tensorflow.op.core.Identity;
import org.tensorflow.op.core.NoOp;
import org.tensorflow.op.core.Placeholder;
import org.tensorflow.op.train.Restore;
Expand Down Expand Up @@ -439,15 +441,32 @@ public Output<?>[] whileLoop(
* Return the {@link SaverDef} instance used to save the state of all variables present in
* this graph.
*
* <p/>On the first call of this method, all nodes necessary to save and restore the state of the
* variables are added to the graph. Consequently, any variables that are added to the graph after
* this call could not be saved nor restored using this {@link SaverDef}.
* <p/> The first time this method is called it builds the {@link SaverDef}. If this graph already
* contains a "save/restore_all" operation then it is assumed to contain all necessary saving and
* restoring operations. If that operation does not exist then the graph is mutated to add all
* the nodes necessary to save and restore the state of the graph. Consequently, any variables
* that are added to the graph after this call will not be saved nor restored using this
* {@link SaverDef}.
*
* @return a {@link SaverDef} instance
*/
synchronized SaverDef saverDef() {
if (saverDef == null) {
saverDef = addVariableSaver(this);
// Check to see if this graph has a restore operation
if (operation("save/restore_all") == null) {
// No saver, create one by mutating the graph
saverDef = addVariableSaver(this);
} else {
// This graph already has saving/restoring operations,
// regenerate SaverDef without mutating. The names mirror
// the python implementation for compatibility.
// https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/training/saver.py
saverDef = SaverDef.newBuilder()
.setFilenameTensorName("save/filename")
.setSaveTensorName("save/control_dependency")
.setRestoreOpName("save/restore_all")
.build();
}
}
return saverDef;
}
Expand Down Expand Up @@ -798,13 +817,15 @@ private static SaverDef addVariableSaver(Graph graph) {
Constant<TString> varNamesTensor = tf.constant(StdArrays.ndCopyOf(varNames.toArray(tmp)));
Operand<TString> varSlices = tf.zerosLike(varNamesTensor);

Placeholder<TString> saveFilename = tf.placeholder(TString.class);
Placeholder<TString> saveFilename = tf.withName("filename").placeholder(TString.class);
Save saveVariables = tf.train.save(
saveFilename,
varNamesTensor,
varSlices,
varOutputs
);
Identity<TString> id = tf.withControlDependencies(Arrays.asList(saveFilename,saveVariables))
.withName("control_dependency").identity(saveFilename);
Restore restoreVariables = tf.train.restore(
saveFilename,
varNamesTensor,
Expand All @@ -815,11 +836,11 @@ private static SaverDef addVariableSaver(Graph graph) {
for (int i = 0; i < varOutputs.size(); ++i) {
restoreOps.add(tf.assign(varOutputs.get(i), (Operand) restoreVariables.tensors().get(i)));
}
NoOp restoreAll = tf.withControlDependencies(restoreOps).noOp();
NoOp restoreAll = tf.withControlDependencies(restoreOps).withName("restore_all").noOp();

return SaverDef.newBuilder()
.setFilenameTensorName(saveFilename.op().name())
.setSaveTensorName(saveVariables.op().name())
.setSaveTensorName(id.op().name())
.setRestoreOpName(restoreAll.op().name())
.build();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -512,16 +512,35 @@ public void runInit(){
* <i>mymodel/myvariables</i> and named <i>variables.data-*-of-*</i>
*
* <p>Note that this method might alter the underlying graph if it is the first time that one
* of its session is saved, see {@link Graph#saverDef()} for more details.
* of its sessions is saved, see {@link Graph#saverDef()} for more details.
*
* @param prefix prefix to the variable files to save
*/
public void save(String prefix) {
SaverDef saverDef = graph.saverDef();
runner()
.addTarget(saverDef.getSaveTensorName())
.feed(saverDef.getFilenameTensorName(), TString.scalarOf(prefix))
.run();
runner().addTarget(saverDef.getSaveTensorName())
.feed(saverDef.getFilenameTensorName(), TString.scalarOf(prefix))
.run();
}

/**
* Restore the actual state of the variables of this session's graph.
*
* <p>{@code prefix} is the path where the files containing the variables state live,
* followed by the filename prefix. For example, if {@code prefix} is set to
* <i>mymodel/myvariables/variables</i>, then the files are loaded from
* <i>mymodel/myvariables</i> and named <i>variables.data-*-of-*</i>
*
* <p>Note that this method might alter the underlying graph if it is the first time that one
* of its sessions is saved, see {@link Graph#saverDef()} for more details.
*
* @param prefix prefix to restore from
*/
public void restore(String prefix) {
SaverDef saverDef = graph.saverDef();
runner().addTarget(saverDef.getRestoreOpName())
.feed(saverDef.getFilenameTensorName(), TString.scalarOf(prefix))
.run();
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,14 @@
import static org.junit.jupiter.api.Assertions.assertTrue;
import static org.junit.jupiter.api.Assertions.fail;

import java.io.BufferedOutputStream;
import java.io.File;
import java.io.FileOutputStream;
import java.io.IOException;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.Comparator;

import org.junit.jupiter.api.Test;
import org.tensorflow.op.Ops;
import org.tensorflow.op.core.Init;
Expand All @@ -32,6 +36,7 @@
import org.tensorflow.op.linalg.MatMul;
import org.tensorflow.op.math.Add;
import org.tensorflow.proto.framework.ConfigProto;
import org.tensorflow.proto.framework.GraphDef;
import org.tensorflow.proto.framework.RunOptions;
import org.tensorflow.ndarray.Shape;
import org.tensorflow.ndarray.NdArrays;
Expand Down Expand Up @@ -208,21 +213,40 @@ public void runInitByName() {
}

@Test
public void save() throws IOException {
Path testFolder = Files.createTempDirectory("tf-session-save-test");
public void saveAndRestore() throws IOException {
Path testFolder = Files.createTempDirectory("tf-session-save-restore-test");
try (Graph g = new Graph()) {
Ops tf = Ops.create(g);
Variable<TFloat32> x = tf.variable(tf.random.randomUniform(tf.constant(Shape.of(3, 3L)), TFloat32.class));
Variable<TFloat32> y = tf.variable(tf.random.randomUniform(tf.constant(Shape.of(3, 3L)), TFloat32.class));
Variable<TFloat32> x = tf.withName("x").variable(tf.random.randomUniform(tf.constant(Shape.of(3, 3L)), TFloat32.class));
Variable<TFloat32> y = tf.withName("y").variable(tf.random.randomUniform(tf.constant(Shape.of(3, 3L)), TFloat32.class));
Init init = tf.init();

try (Session s = new Session(g)) {
s.run(init);
s.save(testFolder.resolve("checkpoint").toString());
GraphDef graphDef = g.toGraphDef();

try (Graph restoredGraph = new Graph()) {
restoredGraph.importGraphDef(graphDef);
try (Session restoredSession = new Session(restoredGraph)) {
restoredSession.restore(testFolder.resolve("checkpoint").toString());
try (AutoCloseableList<Tensor> oldList = new AutoCloseableList<>(s.runner().fetch("x").fetch("y").run());
AutoCloseableList<Tensor> newList = new AutoCloseableList<>(restoredSession.runner().fetch("x").fetch("y").run())){
assertEquals(oldList.get(0),newList.get(0));
assertEquals(oldList.get(1),newList.get(1));
}
}
}
}
}
assertTrue(Files.exists(testFolder.resolve("checkpoint.index")));
assertTrue(Files.exists(testFolder.resolve("checkpoint.data-00000-of-00001")));

// Cleanup test dir
Files.walk(testFolder)
.sorted(Comparator.reverseOrder())
.map(Path::toFile)
.forEach(File::delete);
}

private static RunOptions fullTraceRunOptions() {
Expand Down