diff --git a/gax-java/gax-grpc/src/main/java/com/google/api/gax/grpc/ChannelPool.java b/gax-java/gax-grpc/src/main/java/com/google/api/gax/grpc/ChannelPool.java index 7f045cc44d..7fce8003ab 100644 --- a/gax-java/gax-grpc/src/main/java/com/google/api/gax/grpc/ChannelPool.java +++ b/gax-java/gax-grpc/src/main/java/com/google/api/gax/grpc/ChannelPool.java @@ -68,7 +68,7 @@ *

Package-private for internal use. */ class ChannelPool extends ManagedChannel { - private static final Logger LOG = Logger.getLogger(ChannelPool.class.getName()); + @VisibleForTesting static final Logger LOG = Logger.getLogger(ChannelPool.class.getName()); private static final Duration REFRESH_PERIOD = Duration.ofMinutes(50); private final ChannelPoolSettings settings; @@ -381,14 +381,14 @@ void refresh() { * Get and retain a Channel Entry. The returned Entry will have its rpc count incremented, * preventing it from getting recycled. */ - Entry getRetainedEntry(int affinity) { + RetainedEntry getRetainedEntry(int affinity) { // The maximum number of concurrent calls to this method for any given time span is at most 2, // so the loop can actually be 2 times. But going for 5 times for a safety margin for potential // code evolving for (int i = 0; i < 5; i++) { Entry entry = getEntry(affinity); if (entry.retain()) { - return entry; + return new RetainedEntry(entry); } } // It is unlikely to reach here unless the pool code evolves to increase the maximum possible @@ -415,10 +415,37 @@ private Entry getEntry(int affinity) { return localEntries.get(index); } + /** + * This represents the reserved refcount of a single RPC using a channel. It the responsibility of + * that RPC to call release exactly once when it completes to release the Channel. + */ + private static class RetainedEntry { + private final Entry entry; + private final AtomicBoolean wasReleased; + + public RetainedEntry(Entry entry) { + this.entry = entry; + wasReleased = new AtomicBoolean(false); + } + + void release() { + if (!wasReleased.compareAndSet(false, true)) { + Exception e = new IllegalStateException("Entry was already released"); + LOG.log(Level.WARNING, e.getMessage(), e); + return; + } + entry.release(); + } + + public Channel getChannel() { + return entry.channel; + } + } + /** Bundles a gRPC {@link ManagedChannel} with some usage accounting. */ - private static class Entry { + static class Entry { private final ManagedChannel channel; - private final AtomicInteger outstandingRpcs = new AtomicInteger(0); + final AtomicInteger outstandingRpcs = new AtomicInteger(0); private final AtomicInteger maxOutstanding = new AtomicInteger(); // Flag that the channel should be closed once all of the outstanding RPC complete. @@ -511,18 +538,19 @@ public String authority() { public ClientCall newCall( MethodDescriptor methodDescriptor, CallOptions callOptions) { - Entry entry = getRetainedEntry(affinity); + RetainedEntry entry = getRetainedEntry(affinity); - return new ReleasingClientCall<>(entry.channel.newCall(methodDescriptor, callOptions), entry); + return new ReleasingClientCall<>( + entry.getChannel().newCall(methodDescriptor, callOptions), entry); } } /** ClientCall wrapper that makes sure to decrement the outstanding RPC count on completion. */ static class ReleasingClientCall extends SimpleForwardingClientCall { @Nullable private CancellationException cancellationException; - final Entry entry; + final RetainedEntry entry; - public ReleasingClientCall(ClientCall delegate, Entry entry) { + public ReleasingClientCall(ClientCall delegate, RetainedEntry entry) { super(delegate); this.entry = entry; } diff --git a/gax-java/gax-grpc/src/test/java/com/google/api/gax/grpc/ChannelPoolTest.java b/gax-java/gax-grpc/src/test/java/com/google/api/gax/grpc/ChannelPoolTest.java index c38d98e91f..ac77f0342a 100644 --- a/gax-java/gax-grpc/src/test/java/com/google/api/gax/grpc/ChannelPoolTest.java +++ b/gax-java/gax-grpc/src/test/java/com/google/api/gax/grpc/ChannelPoolTest.java @@ -29,9 +29,11 @@ */ package com.google.api.gax.grpc; +import static com.google.api.gax.grpc.testing.FakeServiceGrpc.METHOD_RECOGNIZE; import static com.google.api.gax.grpc.testing.FakeServiceGrpc.METHOD_SERVER_STREAMING_RECOGNIZE; import static com.google.common.truth.Truth.assertThat; +import com.google.api.core.ApiFuture; import com.google.api.gax.grpc.testing.FakeChannelFactory; import com.google.api.gax.grpc.testing.FakeMethodDescriptor; import com.google.api.gax.grpc.testing.FakeServiceGrpc; @@ -40,6 +42,8 @@ import com.google.api.gax.rpc.ServerStreamingCallSettings; import com.google.api.gax.rpc.ServerStreamingCallable; import com.google.api.gax.rpc.StreamController; +import com.google.api.gax.rpc.UnaryCallSettings; +import com.google.api.gax.rpc.UnaryCallable; import com.google.common.base.Preconditions; import com.google.common.collect.ImmutableList; import com.google.common.collect.Lists; @@ -63,6 +67,9 @@ import java.util.concurrent.ScheduledFuture; import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicInteger; +import java.util.logging.Handler; +import java.util.logging.LogRecord; +import java.util.stream.Collectors; import org.junit.After; import org.junit.Assert; import org.junit.Test; @@ -663,4 +670,72 @@ public void onComplete() {} assertThat(e.getCause()).isInstanceOf(CancellationException.class); assertThat(e.getMessage()).isEqualTo("Call is already cancelled"); } + + @Test + public void testDoubleRelease() throws Exception { + FakeLogHandler logHandler = new FakeLogHandler(); + ChannelPool.LOG.addHandler(logHandler); + + try { + // Create a fake channel pool thats backed by mock channels that simply record invocations + ClientCall mockClientCall = Mockito.mock(ClientCall.class); + ManagedChannel fakeChannel = Mockito.mock(ManagedChannel.class); + Mockito.when(fakeChannel.newCall(Mockito.any(), Mockito.any())).thenReturn(mockClientCall); + ChannelPoolSettings channelPoolSettings = ChannelPoolSettings.staticallySized(1); + ChannelFactory factory = new FakeChannelFactory(ImmutableList.of(fakeChannel)); + + pool = ChannelPool.create(channelPoolSettings, factory); + + // Construct a fake callable to use the channel pool + ClientContext context = + ClientContext.newBuilder() + .setTransportChannel(GrpcTransportChannel.create(pool)) + .setDefaultCallContext(GrpcCallContext.of(pool, CallOptions.DEFAULT)) + .build(); + + UnaryCallSettings settings = + UnaryCallSettings.newUnaryCallSettingsBuilder().build(); + UnaryCallable callable = + GrpcCallableFactory.createUnaryCallable( + GrpcCallSettings.create(METHOD_RECOGNIZE), settings, context); + + // Start the RPC + ApiFuture rpcFuture = + callable.futureCall(Color.getDefaultInstance(), context.getDefaultCallContext()); + + // Get the server side listener and intentionally close it twice + ArgumentCaptor> clientCallListenerCaptor = + ArgumentCaptor.forClass(ClientCall.Listener.class); + Mockito.verify(mockClientCall).start(clientCallListenerCaptor.capture(), Mockito.any()); + clientCallListenerCaptor.getValue().onClose(Status.INTERNAL, new Metadata()); + clientCallListenerCaptor.getValue().onClose(Status.UNKNOWN, new Metadata()); + + // Ensure that the channel pool properly logged the double call and kept the refCount correct + assertThat(logHandler.getAllMessages()).contains("Entry was already released"); + assertThat(pool.entries.get()).hasSize(1); + ChannelPool.Entry entry = pool.entries.get().get(0); + assertThat(entry.outstandingRpcs.get()).isEqualTo(0); + } finally { + ChannelPool.LOG.removeHandler(logHandler); + } + } + + private static class FakeLogHandler extends Handler { + List records = new ArrayList<>(); + + @Override + public void publish(LogRecord record) { + records.add(record); + } + + @Override + public void flush() {} + + @Override + public void close() throws SecurityException {} + + List getAllMessages() { + return records.stream().map(LogRecord::getMessage).collect(Collectors.toList()); + } + } }