diff --git a/core/src/main/java/io/grpc/internal/RetriableStream.java b/core/src/main/java/io/grpc/internal/RetriableStream.java index ab85ec1bdb5..7f1b41cb2d1 100644 --- a/core/src/main/java/io/grpc/internal/RetriableStream.java +++ b/core/src/main/java/io/grpc/internal/RetriableStream.java @@ -282,14 +282,12 @@ private void drain(Substream substream) { synchronized (lock) { savedState = state; - if (streamStarted) { - if (savedState.winningSubstream != null && savedState.winningSubstream != substream) { - // committed but not me, to be cancelled - break; - } - if (savedState.cancelled) { - break; - } + if (savedState.winningSubstream != null && savedState.winningSubstream != substream) { + // committed but not me, to be cancelled + break; + } + if (savedState.cancelled) { + break; } if (index == savedState.buffer.size()) { // I'm drained state = savedState.substreamDrained(substream); @@ -326,15 +324,13 @@ public void run() { if (bufferEntry instanceof RetriableStream.StartEntry) { streamStarted = true; } - if (streamStarted) { - savedState = state; - if (savedState.winningSubstream != null && savedState.winningSubstream != substream) { - // committed but not me, to be cancelled - break; - } - if (savedState.cancelled) { - break; - } + savedState = state; + if (savedState.winningSubstream != null && savedState.winningSubstream != substream) { + // committed but not me, to be cancelled + break; + } + if (savedState.cancelled) { + break; } } } @@ -344,6 +340,10 @@ public void run() { return; } + if (!streamStarted) { + // Start stream so inFlightSubStreams is decremented in Sublistener.closed() + substream.stream.start(new Sublistener(substream)); + } substream.stream.cancel( state.winningSubstream == substream ? cancellationStatus : CANCELLED_BECAUSE_COMMITTED); } @@ -484,6 +484,8 @@ public void run() { } if (cancelled) { + // Start stream so inFlightSubStreams is decremented in Sublistener.closed() + newSubstream.stream.start(new Sublistener(newSubstream)); newSubstream.stream.cancel(Status.CANCELLED.withDescription("Unneeded hedging")); return; } @@ -507,6 +509,9 @@ public final void cancel(final Status reason) { Runnable runnable = commit(noopSubstream); if (runnable != null) { + synchronized (lock) { + state = state.substreamDrained(noopSubstream); + } runnable.run(); safeCloseMasterListener(reason, RpcProgress.PROCESSED, new Metadata()); return; diff --git a/core/src/test/java/io/grpc/internal/RetriableStreamTest.java b/core/src/test/java/io/grpc/internal/RetriableStreamTest.java index bbbf36d37c5..e14e71cb2aa 100644 --- a/core/src/test/java/io/grpc/internal/RetriableStreamTest.java +++ b/core/src/test/java/io/grpc/internal/RetriableStreamTest.java @@ -188,7 +188,7 @@ Status prestart() { } } - private final RetriableStream retriableStream = + private RetriableStream retriableStream = newThrottledRetriableStream(null /* throttle */); private final RetriableStream hedgingStream = newThrottledHedgingStream(null /* throttle */); @@ -196,10 +196,13 @@ Status prestart() { private ClientStreamTracer bufferSizeTracer; private RetriableStream newThrottledRetriableStream(Throttle throttle) { + return newThrottledRetriableStream(throttle, MoreExecutors.directExecutor()); + } + + private RetriableStream newThrottledRetriableStream(Throttle throttle, Executor drainer) { return new RecordedRetriableStream( method, new Metadata(), channelBufferUsed, PER_RPC_BUFFER_LIMIT, CHANNEL_BUFFER_LIMIT, - MoreExecutors.directExecutor(), fakeClock.getScheduledExecutorService(), RETRY_POLICY, - null, throttle); + drainer, fakeClock.getScheduledExecutorService(), RETRY_POLICY, null, throttle); } private RetriableStream newThrottledHedgingStream(Throttle throttle) { @@ -598,6 +601,44 @@ public void retry_cancel_closed() { inOrder.verify(retriableStreamRecorder, never()).postCommit(); } + @Test + public void transparentRetry_cancel_race() { + FakeClock drainer = new FakeClock(); + retriableStream = newThrottledRetriableStream(null, drainer.getScheduledExecutorService()); + ClientStream mockStream1 = mock(ClientStream.class); + doReturn(mockStream1).when(retriableStreamRecorder).newSubstream(0); + InOrder inOrder = inOrder(retriableStreamRecorder); + + retriableStream.start(masterListener); + + ArgumentCaptor sublistenerCaptor1 = + ArgumentCaptor.forClass(ClientStreamListener.class); + verify(mockStream1).start(sublistenerCaptor1.capture()); + + // retry, but don't drain + ClientStream mockStream2 = mock(ClientStream.class); + doReturn(mockStream2).when(retriableStreamRecorder).newSubstream(0); + sublistenerCaptor1.getValue().closed( + Status.fromCode(NON_RETRIABLE_STATUS_CODE), MISCARRIED, new Metadata()); + assertEquals(1, drainer.numPendingTasks()); + + // cancel + retriableStream.cancel(Status.CANCELLED); + // drain transparent retry + drainer.runDueTasks(); + inOrder.verify(retriableStreamRecorder).postCommit(); + + ArgumentCaptor sublistenerCaptor2 = + ArgumentCaptor.forClass(ClientStreamListener.class); + verify(mockStream2).start(sublistenerCaptor2.capture()); + ArgumentCaptor statusCaptor = ArgumentCaptor.forClass(Status.class); + verify(mockStream2).cancel(statusCaptor.capture()); + assertEquals(Status.CANCELLED.getCode(), statusCaptor.getValue().getCode()); + assertEquals(CANCELLED_BECAUSE_COMMITTED, statusCaptor.getValue().getDescription()); + sublistenerCaptor2.getValue().closed(statusCaptor.getValue(), PROCESSED, new Metadata()); + verify(masterListener).closed(same(Status.CANCELLED), same(PROCESSED), any(Metadata.class)); + } + @Test public void unretriableClosed_cancel() { ClientStream mockStream1 = mock(ClientStream.class);