diff --git a/gax-java/gax-grpc/src/main/java/com/google/api/gax/grpc/ChannelPool.java b/gax-java/gax-grpc/src/main/java/com/google/api/gax/grpc/ChannelPool.java
index 7f045cc44d..7fce8003ab 100644
--- a/gax-java/gax-grpc/src/main/java/com/google/api/gax/grpc/ChannelPool.java
+++ b/gax-java/gax-grpc/src/main/java/com/google/api/gax/grpc/ChannelPool.java
@@ -68,7 +68,7 @@
*
Package-private for internal use.
*/
class ChannelPool extends ManagedChannel {
- private static final Logger LOG = Logger.getLogger(ChannelPool.class.getName());
+ @VisibleForTesting static final Logger LOG = Logger.getLogger(ChannelPool.class.getName());
private static final Duration REFRESH_PERIOD = Duration.ofMinutes(50);
private final ChannelPoolSettings settings;
@@ -381,14 +381,14 @@ void refresh() {
* Get and retain a Channel Entry. The returned Entry will have its rpc count incremented,
* preventing it from getting recycled.
*/
- Entry getRetainedEntry(int affinity) {
+ RetainedEntry getRetainedEntry(int affinity) {
// The maximum number of concurrent calls to this method for any given time span is at most 2,
// so the loop can actually be 2 times. But going for 5 times for a safety margin for potential
// code evolving
for (int i = 0; i < 5; i++) {
Entry entry = getEntry(affinity);
if (entry.retain()) {
- return entry;
+ return new RetainedEntry(entry);
}
}
// It is unlikely to reach here unless the pool code evolves to increase the maximum possible
@@ -415,10 +415,37 @@ private Entry getEntry(int affinity) {
return localEntries.get(index);
}
+ /**
+ * This represents the reserved refcount of a single RPC using a channel. It the responsibility of
+ * that RPC to call release exactly once when it completes to release the Channel.
+ */
+ private static class RetainedEntry {
+ private final Entry entry;
+ private final AtomicBoolean wasReleased;
+
+ public RetainedEntry(Entry entry) {
+ this.entry = entry;
+ wasReleased = new AtomicBoolean(false);
+ }
+
+ void release() {
+ if (!wasReleased.compareAndSet(false, true)) {
+ Exception e = new IllegalStateException("Entry was already released");
+ LOG.log(Level.WARNING, e.getMessage(), e);
+ return;
+ }
+ entry.release();
+ }
+
+ public Channel getChannel() {
+ return entry.channel;
+ }
+ }
+
/** Bundles a gRPC {@link ManagedChannel} with some usage accounting. */
- private static class Entry {
+ static class Entry {
private final ManagedChannel channel;
- private final AtomicInteger outstandingRpcs = new AtomicInteger(0);
+ final AtomicInteger outstandingRpcs = new AtomicInteger(0);
private final AtomicInteger maxOutstanding = new AtomicInteger();
// Flag that the channel should be closed once all of the outstanding RPC complete.
@@ -511,18 +538,19 @@ public String authority() {
public ClientCall newCall(
MethodDescriptor methodDescriptor, CallOptions callOptions) {
- Entry entry = getRetainedEntry(affinity);
+ RetainedEntry entry = getRetainedEntry(affinity);
- return new ReleasingClientCall<>(entry.channel.newCall(methodDescriptor, callOptions), entry);
+ return new ReleasingClientCall<>(
+ entry.getChannel().newCall(methodDescriptor, callOptions), entry);
}
}
/** ClientCall wrapper that makes sure to decrement the outstanding RPC count on completion. */
static class ReleasingClientCall extends SimpleForwardingClientCall {
@Nullable private CancellationException cancellationException;
- final Entry entry;
+ final RetainedEntry entry;
- public ReleasingClientCall(ClientCall delegate, Entry entry) {
+ public ReleasingClientCall(ClientCall delegate, RetainedEntry entry) {
super(delegate);
this.entry = entry;
}
diff --git a/gax-java/gax-grpc/src/test/java/com/google/api/gax/grpc/ChannelPoolTest.java b/gax-java/gax-grpc/src/test/java/com/google/api/gax/grpc/ChannelPoolTest.java
index c38d98e91f..ac77f0342a 100644
--- a/gax-java/gax-grpc/src/test/java/com/google/api/gax/grpc/ChannelPoolTest.java
+++ b/gax-java/gax-grpc/src/test/java/com/google/api/gax/grpc/ChannelPoolTest.java
@@ -29,9 +29,11 @@
*/
package com.google.api.gax.grpc;
+import static com.google.api.gax.grpc.testing.FakeServiceGrpc.METHOD_RECOGNIZE;
import static com.google.api.gax.grpc.testing.FakeServiceGrpc.METHOD_SERVER_STREAMING_RECOGNIZE;
import static com.google.common.truth.Truth.assertThat;
+import com.google.api.core.ApiFuture;
import com.google.api.gax.grpc.testing.FakeChannelFactory;
import com.google.api.gax.grpc.testing.FakeMethodDescriptor;
import com.google.api.gax.grpc.testing.FakeServiceGrpc;
@@ -40,6 +42,8 @@
import com.google.api.gax.rpc.ServerStreamingCallSettings;
import com.google.api.gax.rpc.ServerStreamingCallable;
import com.google.api.gax.rpc.StreamController;
+import com.google.api.gax.rpc.UnaryCallSettings;
+import com.google.api.gax.rpc.UnaryCallable;
import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.Lists;
@@ -63,6 +67,9 @@
import java.util.concurrent.ScheduledFuture;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicInteger;
+import java.util.logging.Handler;
+import java.util.logging.LogRecord;
+import java.util.stream.Collectors;
import org.junit.After;
import org.junit.Assert;
import org.junit.Test;
@@ -663,4 +670,72 @@ public void onComplete() {}
assertThat(e.getCause()).isInstanceOf(CancellationException.class);
assertThat(e.getMessage()).isEqualTo("Call is already cancelled");
}
+
+ @Test
+ public void testDoubleRelease() throws Exception {
+ FakeLogHandler logHandler = new FakeLogHandler();
+ ChannelPool.LOG.addHandler(logHandler);
+
+ try {
+ // Create a fake channel pool thats backed by mock channels that simply record invocations
+ ClientCall mockClientCall = Mockito.mock(ClientCall.class);
+ ManagedChannel fakeChannel = Mockito.mock(ManagedChannel.class);
+ Mockito.when(fakeChannel.newCall(Mockito.any(), Mockito.any())).thenReturn(mockClientCall);
+ ChannelPoolSettings channelPoolSettings = ChannelPoolSettings.staticallySized(1);
+ ChannelFactory factory = new FakeChannelFactory(ImmutableList.of(fakeChannel));
+
+ pool = ChannelPool.create(channelPoolSettings, factory);
+
+ // Construct a fake callable to use the channel pool
+ ClientContext context =
+ ClientContext.newBuilder()
+ .setTransportChannel(GrpcTransportChannel.create(pool))
+ .setDefaultCallContext(GrpcCallContext.of(pool, CallOptions.DEFAULT))
+ .build();
+
+ UnaryCallSettings settings =
+ UnaryCallSettings.newUnaryCallSettingsBuilder().build();
+ UnaryCallable callable =
+ GrpcCallableFactory.createUnaryCallable(
+ GrpcCallSettings.create(METHOD_RECOGNIZE), settings, context);
+
+ // Start the RPC
+ ApiFuture rpcFuture =
+ callable.futureCall(Color.getDefaultInstance(), context.getDefaultCallContext());
+
+ // Get the server side listener and intentionally close it twice
+ ArgumentCaptor> clientCallListenerCaptor =
+ ArgumentCaptor.forClass(ClientCall.Listener.class);
+ Mockito.verify(mockClientCall).start(clientCallListenerCaptor.capture(), Mockito.any());
+ clientCallListenerCaptor.getValue().onClose(Status.INTERNAL, new Metadata());
+ clientCallListenerCaptor.getValue().onClose(Status.UNKNOWN, new Metadata());
+
+ // Ensure that the channel pool properly logged the double call and kept the refCount correct
+ assertThat(logHandler.getAllMessages()).contains("Entry was already released");
+ assertThat(pool.entries.get()).hasSize(1);
+ ChannelPool.Entry entry = pool.entries.get().get(0);
+ assertThat(entry.outstandingRpcs.get()).isEqualTo(0);
+ } finally {
+ ChannelPool.LOG.removeHandler(logHandler);
+ }
+ }
+
+ private static class FakeLogHandler extends Handler {
+ List records = new ArrayList<>();
+
+ @Override
+ public void publish(LogRecord record) {
+ records.add(record);
+ }
+
+ @Override
+ public void flush() {}
+
+ @Override
+ public void close() throws SecurityException {}
+
+ List getAllMessages() {
+ return records.stream().map(LogRecord::getMessage).collect(Collectors.toList());
+ }
+ }
}