diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/EagerOperation.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/EagerOperation.java index a5c2df84026..9f87fd8b95e 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/EagerOperation.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/EagerOperation.java @@ -53,8 +53,6 @@ class EagerOperation extends AbstractOperation { this.name = name; this.opHandle = opNativeHandle; this.outputHandles = outputNativeHandles; - session.attach(opNativeHandle); - session.attach(outputNativeHandles); this.outputTensors = new AtomicReferenceArray<>(outputNativeHandles.length); } diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/EagerOperationBuilder.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/EagerOperationBuilder.java index a865300bc5a..f1dd6216a79 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/EagerOperationBuilder.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/EagerOperationBuilder.java @@ -65,12 +65,7 @@ final class EagerOperationBuilder implements OperationBuilder { @Override public EagerOperation build() { TFE_TensorHandle[] tensorHandles = execute(opHandle, session); - EagerOperation operation = - new EagerOperation(session, opHandle, tensorHandles, type, name); - // Release our reference to the native op handle now that we transferred its - // ownership to the EagerOperation - session.detach(opHandle); - return operation; + return new EagerOperation(session, opHandle, tensorHandles, type, name); } @Override diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/EagerSession.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/EagerSession.java index 75bc12b5a6c..8e7465388a8 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/EagerSession.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/EagerSession.java @@ -24,6 +24,7 @@ import org.bytedeco.javacpp.BytePointer; import org.bytedeco.javacpp.Pointer; import org.bytedeco.javacpp.PointerScope; +import org.tensorflow.internal.WeakPointerScope; import org.tensorflow.internal.c_api.TFE_Context; import org.tensorflow.internal.c_api.TFE_ContextOptions; import org.tensorflow.internal.c_api.TF_Status; @@ -310,6 +311,23 @@ TFE_Context nativeHandle() { return nativeHandle; } + /** + * Attach the list of native resources to this eager session scope. + * + *

When the eager session is closed (i.e. by calling {@link #close()} explicitly or + * implicitly via try-with-resources), all native resources attached to the session will be + * released as well, unless so other references are {@link Pointer#retainReference() retaining} + * them.

+ * + *

Attached resources can still be garbage collected though if their associated {@link Pointer} + * is no longer reachable in Java, independently of their reference count. Therefore, it is + * assumed that these resources are not required by the native library once the Java client no + * longer needs them.

+ * + *

Attaching a resource already attached to this session will have no effect.

+ * + * @param resources resources to attach to the session + */ void attach(Pointer... resources) { checkSession(); for (Pointer r : resources) { @@ -317,6 +335,21 @@ void attach(Pointer... resources) { } } + /** + * Detach a list of resources from this eager session scope. + * + *

Detached native resources will prevent them to be automatically released when the session is + * closed.

+ * + *

Note though that this method will decrement the reference count of each resources being + * detached, which may automatically released them if that count reaches 0. Therefore, + * invoking {@link Pointer#retainReference()} prior to this call on any resource that must remain + * valid after being detached might be required.

+ * + *

Detaching a resource that is not attached to this session will have no effect.

+ * + * @param resources resources to detach from the session + */ void detach(Pointer... resources) { checkSession(); for (Pointer r : resources) { @@ -326,14 +359,12 @@ void detach(Pointer... resources) { private static volatile EagerSession defaultSession = null; - private final PointerScope nativeResources; + private final WeakPointerScope nativeResources; private TFE_Context nativeHandle; private EagerSession(Options options) { - try (PointerScope scope = new PointerScope()) { - this.nativeResources = scope.extend(); - this.nativeHandle = allocate(options.async, options.devicePlacementPolicy.code, options.config); - } + this.nativeResources = new WeakPointerScope(); + this.nativeHandle = allocate(options.async, options.devicePlacementPolicy.code, options.config); } private void checkSession() { @@ -363,7 +394,7 @@ private static TFE_Context allocate(boolean async, int devicePlacementPolicy, Co TFE_ContextOptionsSetDevicePlacementPolicy(opts, devicePlacementPolicy); TFE_Context context = TFE_NewContext(opts, status); status.throwExceptionIfNotOK(); - return context; + return context.retainReference(); } } diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/internal/WeakPointerScope.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/internal/WeakPointerScope.java new file mode 100644 index 00000000000..f12e97c2702 --- /dev/null +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/internal/WeakPointerScope.java @@ -0,0 +1,70 @@ +package org.tensorflow.internal; + +import java.util.Collections; +import java.util.Set; +import java.util.WeakHashMap; +import org.bytedeco.javacpp.Pointer; + +/** + * A minimalist pointer scope only keeping weak references to its elements. + * + *

As opposed to {@link org.bytedeco.javacpp.PointerScope}, instances of this class will not + * prevent the garbage collector to free the memory of a pointer that is no longer reachable, even + * if it has been attached to the scope.

+ * + *

When the scope is closed, all pointers that are still valid will be automatically deallocated + * while those already garbage-collected will be ignored.

+ */ +public class WeakPointerScope implements AutoCloseable { + + /** + * Attach a pointer to this scope. + * + *

Pointers attached to the scope will be automatically freed once the scope is closed, unless + * they have been already released by the garbage collector

+ * + *

It this {@code pointer} was already attached to this scope, this method has no effect.

+ * + * @param pointer pointer to attach + * @throws IllegalStateException if that scope has already been closed + */ + public void attach(Pointer pointer) { + checkScope(); + if (pointers.add(pointer)) { + pointer.retainReference(); + } + } + + /** + * Detach a pointer from this scope. + * + *

Detaching a pointer from the scope will prevent its memory to be freed when closing the + * scope.

+ * + *

If this {@code pointer} is not attached to this scope, this method has no effect.

+ * + * @param pointer pointer to detach + * @throws IllegalStateException if that scope has already been closed + */ + public void detach(Pointer pointer) { + checkScope(); + if (pointers.remove(pointer)) { + pointer.releaseReference(); + } + } + + @Override + public synchronized void close() { + checkScope(); + pointers.forEach(Pointer::releaseReference); + pointers = null; + } + + private Set pointers = Collections.newSetFromMap(new WeakHashMap<>()); + + private void checkScope() { + if (pointers == null) { + throw new IllegalStateException("Pointer scope has been closed"); + } + } +} diff --git a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/EagerOperationTest.java b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/EagerOperationTest.java index 2920fbdf59f..38714b86599 100644 --- a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/EagerOperationTest.java +++ b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/EagerOperationTest.java @@ -35,8 +35,12 @@ public class EagerOperationTest { public void failToCreateIfSessionIsClosed() { EagerSession session = EagerSession.create(); session.close(); - try { - new EagerOperation(session, null, null, "Add", "add"); + try (TInt32 t = TInt32.tensorOf(Shape.of(2, 3))) { + EagerOperation op = + opBuilder(session, "Const", "OutputAttrs") + .setAttr("dtype", t.dataType()) + .setAttr("value", t) + .build(); fail(); } catch (IllegalStateException e) { // expected diff --git a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/EagerSessionTest.java b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/EagerSessionTest.java index 7ac54213a0b..77325d50dcc 100644 --- a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/EagerSessionTest.java +++ b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/EagerSessionTest.java @@ -59,7 +59,6 @@ public void cleanupResourceInBackground() { sleep(50); // allow some time to the background thread for cleaning up resources long before = Pointer.totalBytes(); - s.detach(ref.retainReference()); ref = null; System.gc(); sleep(50); // allow some time to the background thread for cleaning up resources diff --git a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/internal/WeakPointerScopeTest.java b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/internal/WeakPointerScopeTest.java new file mode 100644 index 00000000000..815a1200c89 --- /dev/null +++ b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/internal/WeakPointerScopeTest.java @@ -0,0 +1,114 @@ +package org.tensorflow.internal; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertNull; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; + +import org.bytedeco.javacpp.IntPointer; +import org.bytedeco.javacpp.Pointer; +import org.junit.jupiter.api.Test; +import org.tensorflow.EagerSession; + +public class WeakPointerScopeTest { + + @Test + public void resourcesAttachedAreFreedOnScopeClose() { + Pointer pointer = new IntPointer(10L); + assertEquals(0, pointer.referenceCount()); + + try (WeakPointerScope scope = new WeakPointerScope()) { + scope.attach(pointer); + assertEquals(1, pointer.referenceCount()); + } + assertTrue(pointer.isNull()); + } + + @Test + public void resourcesDetachedAreNotFreedOnScopeCloseWhenRetained() { + Pointer pointer = new IntPointer(10L); + + try (WeakPointerScope scope = new WeakPointerScope()) { + scope.attach(pointer); + scope.detach(pointer.retainReference()); + } + assertFalse(pointer.isNull()); + assertEquals(1, pointer.referenceCount()); + pointer.deallocate(); + } + + @Test + public void resourcesDetachedAreFreedWhenNotRetained() { + Pointer pointer = new IntPointer(10L); + + try (WeakPointerScope scope = new WeakPointerScope()) { + scope.attach(pointer); + + scope.detach(pointer); + assertTrue(pointer.isNull()); + } + } + + @Test + public void attachingResourceMoreThanOnceHasNoEffect() { + Pointer pointer = new IntPointer(10L); + + try (WeakPointerScope scope = new WeakPointerScope()) { + scope.attach(pointer); + scope.attach(pointer); + assertEquals(1, pointer.referenceCount()); + + Pointer pointerCopy = new Pointer(pointer); + assertEquals(1, pointerCopy.referenceCount()); + scope.attach(pointerCopy); + assertEquals(1, pointerCopy.referenceCount()); + } + assertTrue(pointer.isNull()); + } + + @Test + public void detachingUnattachedResourceHasNoEffect() { + Pointer pointer = new IntPointer(10L); + pointer.retainReference(); + assertEquals(1, pointer.referenceCount()); + + try (WeakPointerScope scope = new WeakPointerScope()) { + scope.detach(pointer); + assertEquals(1, pointer.referenceCount()); + } + assertFalse(pointer.isNull()); + pointer.deallocate(); + } + + @Test + public void operationOnClosedScopeFails() { + Pointer pointer = new IntPointer(10L); + WeakPointerScope scope = new WeakPointerScope(); + scope.close(); + + assertThrows(IllegalStateException.class, () -> scope.attach(pointer)); + assertThrows(IllegalStateException.class, () -> scope.detach(pointer)); + assertThrows(IllegalStateException.class, () -> scope.close()); + + pointer.deallocate(); + } + + @Test + public void attachingResourceDoesNotPreventItToBeGarbageCollected() throws InterruptedException { + try (WeakPointerScope scope = new WeakPointerScope()) { + Pointer pointer = new IntPointer(10L); + scope.attach(pointer); + System.gc(); + Thread.sleep(50); + + long before = Pointer.totalBytes(); + pointer = null; + System.gc(); + Thread.sleep(50); + long after = Pointer.totalBytes(); + + assertEquals(4 * 10L, before - after); + } + } +}