Skip to content

Commit f6024dd

Browse files
authored
Keep weak references to eager resources in session (#229)
1 parent c4498eb commit f6024dd

File tree

7 files changed

+228
-17
lines changed

7 files changed

+228
-17
lines changed

tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/EagerOperation.java

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -53,8 +53,6 @@ class EagerOperation extends AbstractOperation {
5353
this.name = name;
5454
this.opHandle = opNativeHandle;
5555
this.outputHandles = outputNativeHandles;
56-
session.attach(opNativeHandle);
57-
session.attach(outputNativeHandles);
5856
this.outputTensors = new AtomicReferenceArray<>(outputNativeHandles.length);
5957
}
6058

tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/EagerOperationBuilder.java

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -65,12 +65,7 @@ final class EagerOperationBuilder implements OperationBuilder {
6565
@Override
6666
public EagerOperation build() {
6767
TFE_TensorHandle[] tensorHandles = execute(opHandle, session);
68-
EagerOperation operation =
69-
new EagerOperation(session, opHandle, tensorHandles, type, name);
70-
// Release our reference to the native op handle now that we transferred its
71-
// ownership to the EagerOperation
72-
session.detach(opHandle);
73-
return operation;
68+
return new EagerOperation(session, opHandle, tensorHandles, type, name);
7469
}
7570

7671
@Override

tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/EagerSession.java

Lines changed: 37 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
import org.bytedeco.javacpp.BytePointer;
2525
import org.bytedeco.javacpp.Pointer;
2626
import org.bytedeco.javacpp.PointerScope;
27+
import org.tensorflow.internal.WeakPointerScope;
2728
import org.tensorflow.internal.c_api.TFE_Context;
2829
import org.tensorflow.internal.c_api.TFE_ContextOptions;
2930
import org.tensorflow.internal.c_api.TF_Status;
@@ -310,13 +311,45 @@ TFE_Context nativeHandle() {
310311
return nativeHandle;
311312
}
312313

314+
/**
315+
* Attach the list of native resources to this eager session scope.
316+
*
317+
* <p>When the eager session is closed (i.e. by calling {@link #close()} explicitly or
318+
* implicitly via try-with-resources), all native resources attached to the session will be
319+
* released as well, unless so other references are {@link Pointer#retainReference() retaining}
320+
* them.</p>
321+
*
322+
* <p>Attached resources can still be garbage collected though if their associated {@link Pointer}
323+
* is no longer reachable in Java, independently of their reference count. Therefore, it is
324+
* assumed that these resources are not required by the native library once the Java client no
325+
* longer needs them.</p>
326+
*
327+
* <p>Attaching a resource already attached to this session will have no effect.</p>
328+
*
329+
* @param resources resources to attach to the session
330+
*/
313331
void attach(Pointer... resources) {
314332
checkSession();
315333
for (Pointer r : resources) {
316334
nativeResources.attach(r);
317335
}
318336
}
319337

338+
/**
339+
* Detach a list of resources from this eager session scope.
340+
*
341+
* <p>Detached native resources will prevent them to be automatically released when the session is
342+
* closed.</p>
343+
*
344+
* <p>Note though that this method will decrement the reference count of each resources being
345+
* detached, which may automatically released them if that count reaches 0. Therefore,
346+
* invoking {@link Pointer#retainReference()} prior to this call on any resource that must remain
347+
* valid after being detached might be required.</p>
348+
*
349+
* <p>Detaching a resource that is not attached to this session will have no effect.</p>
350+
*
351+
* @param resources resources to detach from the session
352+
*/
320353
void detach(Pointer... resources) {
321354
checkSession();
322355
for (Pointer r : resources) {
@@ -326,14 +359,12 @@ void detach(Pointer... resources) {
326359

327360
private static volatile EagerSession defaultSession = null;
328361

329-
private final PointerScope nativeResources;
362+
private final WeakPointerScope nativeResources;
330363
private TFE_Context nativeHandle;
331364

332365
private EagerSession(Options options) {
333-
try (PointerScope scope = new PointerScope()) {
334-
this.nativeResources = scope.extend();
335-
this.nativeHandle = allocate(options.async, options.devicePlacementPolicy.code, options.config);
336-
}
366+
this.nativeResources = new WeakPointerScope();
367+
this.nativeHandle = allocate(options.async, options.devicePlacementPolicy.code, options.config);
337368
}
338369

339370
private void checkSession() {
@@ -363,7 +394,7 @@ private static TFE_Context allocate(boolean async, int devicePlacementPolicy, Co
363394
TFE_ContextOptionsSetDevicePlacementPolicy(opts, devicePlacementPolicy);
364395
TFE_Context context = TFE_NewContext(opts, status);
365396
status.throwExceptionIfNotOK();
366-
return context;
397+
return context.retainReference();
367398
}
368399
}
369400

Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,70 @@
1+
package org.tensorflow.internal;
2+
3+
import java.util.Collections;
4+
import java.util.Set;
5+
import java.util.WeakHashMap;
6+
import org.bytedeco.javacpp.Pointer;
7+
8+
/**
9+
* A minimalist pointer scope only keeping weak references to its elements.
10+
*
11+
* <p>As opposed to {@link org.bytedeco.javacpp.PointerScope}, instances of this class will not
12+
* prevent the garbage collector to free the memory of a pointer that is no longer reachable, even
13+
* if it has been attached to the scope.</p>
14+
*
15+
* <p>When the scope is closed, all pointers that are still valid will be automatically deallocated
16+
* while those already garbage-collected will be ignored.</p>
17+
*/
18+
public class WeakPointerScope implements AutoCloseable {
19+
20+
/**
21+
* Attach a pointer to this scope.
22+
*
23+
* <p>Pointers attached to the scope will be automatically freed once the scope is closed, unless
24+
* they have been already released by the garbage collector</p>
25+
*
26+
* <p>It this {@code pointer} was already attached to this scope, this method has no effect.</p>
27+
*
28+
* @param pointer pointer to attach
29+
* @throws IllegalStateException if that scope has already been closed
30+
*/
31+
public void attach(Pointer pointer) {
32+
checkScope();
33+
if (pointers.add(pointer)) {
34+
pointer.retainReference();
35+
}
36+
}
37+
38+
/**
39+
* Detach a pointer from this scope.
40+
*
41+
* <p>Detaching a pointer from the scope will prevent its memory to be freed when closing the
42+
* scope.</p>
43+
*
44+
* <p>If this {@code pointer} is not attached to this scope, this method has no effect.</p>
45+
*
46+
* @param pointer pointer to detach
47+
* @throws IllegalStateException if that scope has already been closed
48+
*/
49+
public void detach(Pointer pointer) {
50+
checkScope();
51+
if (pointers.remove(pointer)) {
52+
pointer.releaseReference();
53+
}
54+
}
55+
56+
@Override
57+
public synchronized void close() {
58+
checkScope();
59+
pointers.forEach(Pointer::releaseReference);
60+
pointers = null;
61+
}
62+
63+
private Set<Pointer> pointers = Collections.newSetFromMap(new WeakHashMap<>());
64+
65+
private void checkScope() {
66+
if (pointers == null) {
67+
throw new IllegalStateException("Pointer scope has been closed");
68+
}
69+
}
70+
}

tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/EagerOperationTest.java

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -35,8 +35,12 @@ public class EagerOperationTest {
3535
public void failToCreateIfSessionIsClosed() {
3636
EagerSession session = EagerSession.create();
3737
session.close();
38-
try {
39-
new EagerOperation(session, null, null, "Add", "add");
38+
try (TInt32 t = TInt32.tensorOf(Shape.of(2, 3))) {
39+
EagerOperation op =
40+
opBuilder(session, "Const", "OutputAttrs")
41+
.setAttr("dtype", t.dataType())
42+
.setAttr("value", t)
43+
.build();
4044
fail();
4145
} catch (IllegalStateException e) {
4246
// expected

tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/EagerSessionTest.java

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,6 @@ public void cleanupResourceInBackground() {
5959
sleep(50); // allow some time to the background thread for cleaning up resources
6060

6161
long before = Pointer.totalBytes();
62-
s.detach(ref.retainReference());
6362
ref = null;
6463
System.gc();
6564
sleep(50); // allow some time to the background thread for cleaning up resources
Lines changed: 114 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,114 @@
1+
package org.tensorflow.internal;
2+
3+
import static org.junit.jupiter.api.Assertions.assertEquals;
4+
import static org.junit.jupiter.api.Assertions.assertFalse;
5+
import static org.junit.jupiter.api.Assertions.assertNull;
6+
import static org.junit.jupiter.api.Assertions.assertThrows;
7+
import static org.junit.jupiter.api.Assertions.assertTrue;
8+
9+
import org.bytedeco.javacpp.IntPointer;
10+
import org.bytedeco.javacpp.Pointer;
11+
import org.junit.jupiter.api.Test;
12+
import org.tensorflow.EagerSession;
13+
14+
public class WeakPointerScopeTest {
15+
16+
@Test
17+
public void resourcesAttachedAreFreedOnScopeClose() {
18+
Pointer pointer = new IntPointer(10L);
19+
assertEquals(0, pointer.referenceCount());
20+
21+
try (WeakPointerScope scope = new WeakPointerScope()) {
22+
scope.attach(pointer);
23+
assertEquals(1, pointer.referenceCount());
24+
}
25+
assertTrue(pointer.isNull());
26+
}
27+
28+
@Test
29+
public void resourcesDetachedAreNotFreedOnScopeCloseWhenRetained() {
30+
Pointer pointer = new IntPointer(10L);
31+
32+
try (WeakPointerScope scope = new WeakPointerScope()) {
33+
scope.attach(pointer);
34+
scope.detach(pointer.retainReference());
35+
}
36+
assertFalse(pointer.isNull());
37+
assertEquals(1, pointer.referenceCount());
38+
pointer.deallocate();
39+
}
40+
41+
@Test
42+
public void resourcesDetachedAreFreedWhenNotRetained() {
43+
Pointer pointer = new IntPointer(10L);
44+
45+
try (WeakPointerScope scope = new WeakPointerScope()) {
46+
scope.attach(pointer);
47+
48+
scope.detach(pointer);
49+
assertTrue(pointer.isNull());
50+
}
51+
}
52+
53+
@Test
54+
public void attachingResourceMoreThanOnceHasNoEffect() {
55+
Pointer pointer = new IntPointer(10L);
56+
57+
try (WeakPointerScope scope = new WeakPointerScope()) {
58+
scope.attach(pointer);
59+
scope.attach(pointer);
60+
assertEquals(1, pointer.referenceCount());
61+
62+
Pointer pointerCopy = new Pointer(pointer);
63+
assertEquals(1, pointerCopy.referenceCount());
64+
scope.attach(pointerCopy);
65+
assertEquals(1, pointerCopy.referenceCount());
66+
}
67+
assertTrue(pointer.isNull());
68+
}
69+
70+
@Test
71+
public void detachingUnattachedResourceHasNoEffect() {
72+
Pointer pointer = new IntPointer(10L);
73+
pointer.retainReference();
74+
assertEquals(1, pointer.referenceCount());
75+
76+
try (WeakPointerScope scope = new WeakPointerScope()) {
77+
scope.detach(pointer);
78+
assertEquals(1, pointer.referenceCount());
79+
}
80+
assertFalse(pointer.isNull());
81+
pointer.deallocate();
82+
}
83+
84+
@Test
85+
public void operationOnClosedScopeFails() {
86+
Pointer pointer = new IntPointer(10L);
87+
WeakPointerScope scope = new WeakPointerScope();
88+
scope.close();
89+
90+
assertThrows(IllegalStateException.class, () -> scope.attach(pointer));
91+
assertThrows(IllegalStateException.class, () -> scope.detach(pointer));
92+
assertThrows(IllegalStateException.class, () -> scope.close());
93+
94+
pointer.deallocate();
95+
}
96+
97+
@Test
98+
public void attachingResourceDoesNotPreventItToBeGarbageCollected() throws InterruptedException {
99+
try (WeakPointerScope scope = new WeakPointerScope()) {
100+
Pointer pointer = new IntPointer(10L);
101+
scope.attach(pointer);
102+
System.gc();
103+
Thread.sleep(50);
104+
105+
long before = Pointer.totalBytes();
106+
pointer = null;
107+
System.gc();
108+
Thread.sleep(50);
109+
long after = Pointer.totalBytes();
110+
111+
assertEquals(4 * 10L, before - after);
112+
}
113+
}
114+
}

0 commit comments

Comments
 (0)