From edb139cd6e805173b6d3ae968437009181b34367 Mon Sep 17 00:00:00 2001 From: Adam Pocock Date: Tue, 23 Feb 2021 16:28:44 -0500 Subject: [PATCH 1/5] Initial support for restoring a saved session. --- .../src/main/java/org/tensorflow/Session.java | 20 +++++++++++- .../test/java/org/tensorflow/SessionTest.java | 31 ++++++++++++++++--- 2 files changed, 46 insertions(+), 5 deletions(-) diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Session.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Session.java index e156491d09a..42216f31c03 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Session.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Session.java @@ -519,7 +519,25 @@ public void runInit(){ public void save(String prefix) { SaverDef saverDef = graph.saverDef(); runner() - .addTarget(saverDef.getSaveTensorName()) + .addTarget(saverDef.getSaveTensorName()) + .feed(saverDef.getFilenameTensorName(), TString.scalarOf(prefix)) + .run(); + } + + /** + * Restore the actual state of the variables of this session's graph. + * + *

{@code prefix} is the path where the files containing the variables state live, + * followed by the filename prefix. For example, if {@code prefix} is set to + * mymodel/myvariables/variables, then the files are loaded from + * mymodel/myvariables and named variables.data-*-of-* + * + * @param prefix prefix to restore from + */ + public void restore(String prefix) { + SaverDef saverDef = graph.saverDef(); + runner() + .addTarget(saverDef.getRestoreOpName()) .feed(saverDef.getFilenameTensorName(), TString.scalarOf(prefix)) .run(); } diff --git a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/SessionTest.java b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/SessionTest.java index b1928bff51c..4572b7232fd 100644 --- a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/SessionTest.java +++ b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/SessionTest.java @@ -20,10 +20,14 @@ import static org.junit.jupiter.api.Assertions.assertTrue; import static org.junit.jupiter.api.Assertions.fail; +import java.io.File; import java.io.IOException; import java.nio.file.Files; import java.nio.file.Path; import java.nio.file.Paths; +import java.util.Comparator; +import java.util.List; + import org.junit.jupiter.api.Test; import org.tensorflow.op.Ops; import org.tensorflow.op.core.Init; @@ -32,6 +36,7 @@ import org.tensorflow.op.linalg.MatMul; import org.tensorflow.op.math.Add; import org.tensorflow.proto.framework.ConfigProto; +import org.tensorflow.proto.framework.GraphDef; import org.tensorflow.proto.framework.RunOptions; import org.tensorflow.ndarray.Shape; import org.tensorflow.ndarray.NdArrays; @@ -208,21 +213,39 @@ public void runInitByName() { } @Test - public void save() throws IOException { - Path testFolder = Files.createTempDirectory("tf-session-save-test"); + public void saveAndRestore() throws IOException { + Path testFolder = Files.createTempDirectory("tf-session-save-restore-test"); try (Graph g = new Graph()) { Ops tf = Ops.create(g); - Variable x = tf.variable(tf.random.randomUniform(tf.constant(Shape.of(3, 3L)), TFloat32.class)); - Variable y = tf.variable(tf.random.randomUniform(tf.constant(Shape.of(3, 3L)), TFloat32.class)); + Variable x = tf.withName("x").variable(tf.random.randomUniform(tf.constant(Shape.of(3, 3L)), TFloat32.class)); + Variable y = tf.withName("y").variable(tf.random.randomUniform(tf.constant(Shape.of(3, 3L)), TFloat32.class)); Init init = tf.init(); + GraphDef graphDef = g.toGraphDef(); try (Session s = new Session(g)) { s.run(init); s.save(testFolder.resolve("checkpoint").toString()); + try (Graph restoredGraph = new Graph()) { + restoredGraph.importGraphDef(graphDef); + try (Session restoredSession = new Session(restoredGraph)) { + restoredSession.restore(testFolder.resolve("checkpoint").toString()); + try (AutoCloseableList oldList = new AutoCloseableList<>(s.runner().fetch("x").fetch("y").run()); + AutoCloseableList newList = new AutoCloseableList<>(restoredSession.runner().fetch("x").fetch("y").run())){ + assertEquals(oldList.get(0),newList.get(0)); + assertEquals(oldList.get(1),newList.get(1)); + } + } + } } } assertTrue(Files.exists(testFolder.resolve("checkpoint.index"))); assertTrue(Files.exists(testFolder.resolve("checkpoint.data-00000-of-00001"))); + + // Cleanup test dir + Files.walk(testFolder) + .sorted(Comparator.reverseOrder()) + .map(Path::toFile) + .forEach(File::delete); } private static RunOptions fullTraceRunOptions() { From 03ee4e59df296e5f55a62a377496c18c1bb78d7e Mon Sep 17 00:00:00 2001 From: Adam Pocock Date: Tue, 23 Feb 2021 17:47:36 -0500 Subject: [PATCH 2/5] Updating surefire version to 3.0.0-M5 to get around a stdout test issue. --- tensorflow-core/tensorflow-core-api/pom.xml | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/tensorflow-core/tensorflow-core-api/pom.xml b/tensorflow-core/tensorflow-core-api/pom.xml index f6a2ca2a14c..fd3605d5759 100644 --- a/tensorflow-core/tensorflow-core-api/pom.xml +++ b/tensorflow-core/tensorflow-core-api/pom.xml @@ -373,7 +373,7 @@ maven-surefire-plugin - 2.22.0 + 3.0.0-M5 + ${project.build.directory}/${project.artifactId}-${project.version}-${native.classifier}.jar