diff --git a/http-clients/aws-crt-client/src/main/java/software/amazon/awssdk/http/crt/internal/response/CrtResponseAdapter.java b/http-clients/aws-crt-client/src/main/java/software/amazon/awssdk/http/crt/internal/response/CrtResponseAdapter.java index b2865ac49531..23ae027c8992 100644 --- a/http-clients/aws-crt-client/src/main/java/software/amazon/awssdk/http/crt/internal/response/CrtResponseAdapter.java +++ b/http-clients/aws-crt-client/src/main/java/software/amazon/awssdk/http/crt/internal/response/CrtResponseAdapter.java @@ -27,7 +27,6 @@ import software.amazon.awssdk.crt.http.HttpHeaderBlock; import software.amazon.awssdk.crt.http.HttpStream; import software.amazon.awssdk.crt.http.HttpStreamResponseHandler; -import software.amazon.awssdk.http.HttpStatusFamily; import software.amazon.awssdk.http.SdkHttpResponse; import software.amazon.awssdk.http.async.SdkAsyncHttpResponseHandler; import software.amazon.awssdk.http.crt.AwsCrtAsyncHttpClient; @@ -49,7 +48,8 @@ public final class CrtResponseAdapter implements HttpStreamResponseHandler { private final SdkAsyncHttpResponseHandler responseHandler; private final SimplePublisher responsePublisher = new SimplePublisher<>(); - private final SdkHttpResponse.Builder responseBuilder = SdkHttpResponse.builder(); + private final SdkHttpResponse.Builder responseBuilder; + private final ResponseHandlerHelper responseHandlerHelper; private CrtResponseAdapter(HttpClientConnection connection, CompletableFuture completionFuture, @@ -57,6 +57,8 @@ private CrtResponseAdapter(HttpClientConnection connection, this.connection = Validate.paramNotNull(connection, "connection"); this.completionFuture = Validate.paramNotNull(completionFuture, "completionFuture"); this.responseHandler = Validate.paramNotNull(responseHandler, "responseHandler"); + this.responseBuilder = SdkHttpResponse.builder(); + this.responseHandlerHelper = new ResponseHandlerHelper(responseBuilder, connection); } public static HttpStreamResponseHandler toCrtResponseHandler(HttpClientConnection crtConn, @@ -66,18 +68,13 @@ public static HttpStreamResponseHandler toCrtResponseHandler(HttpClientConnectio } @Override - public void onResponseHeaders(HttpStream stream, int responseStatusCode, int headerType, HttpHeader[] nextHeaders) { - if (headerType == HttpHeaderBlock.MAIN.getValue()) { - for (HttpHeader h : nextHeaders) { - responseBuilder.appendHeader(h.getName(), h.getValue()); - } - } + public void onResponseHeaders(HttpStream stream, int responseStatusCode, int blockType, HttpHeader[] nextHeaders) { + responseHandlerHelper.onResponseHeaders(stream, responseStatusCode, blockType, nextHeaders); } @Override public void onResponseHeadersDone(HttpStream stream, int headerType) { if (headerType == HttpHeaderBlock.MAIN.getValue()) { - responseBuilder.statusCode(stream.getResponseStatusCode()); responseHandler.onHeaders(responseBuilder.build()); responseHandler.onStream(responsePublisher); } @@ -94,7 +91,7 @@ public int onResponseBody(HttpStream stream, byte[] bodyBytesIn) { writeFuture.whenComplete((result, failure) -> { if (failure != null) { - failResponseHandlerAndFuture(stream, failure); + handlePublisherError(stream, failure); return; } @@ -116,18 +113,18 @@ public void onResponseComplete(HttpStream stream, int errorCode) { private void onSuccessfulResponseComplete(HttpStream stream) { responsePublisher.complete().whenComplete((result, failure) -> { if (failure != null) { - failResponseHandlerAndFuture(stream, failure); + handlePublisherError(stream, failure); return; } - - if (HttpStatusFamily.of(responseBuilder.statusCode()) == HttpStatusFamily.SERVER_ERROR) { - connection.shutdown(); - } - - connection.close(); - stream.close(); completionFuture.complete(null); }); + + responseHandlerHelper.cleanUpConnectionBasedOnStatusCode(stream); + } + + private void handlePublisherError(HttpStream stream, Throwable failure) { + failResponseHandlerAndFuture(stream, failure); + responseHandlerHelper.releaseConnection(stream); } private void onFailedResponseComplete(HttpStream stream, HttpException error) { @@ -136,14 +133,12 @@ private void onFailedResponseComplete(HttpStream stream, HttpException error) { Throwable toThrow = wrapWithIoExceptionIfRetryable(error);; responsePublisher.error(toThrow); failResponseHandlerAndFuture(stream, toThrow); + responseHandlerHelper.closeConnection(stream); } private void failResponseHandlerAndFuture(HttpStream stream, Throwable error) { callResponseHandlerOnError(error); completionFuture.completeExceptionally(error); - connection.shutdown(); - connection.close(); - stream.close(); } private void callResponseHandlerOnError(Throwable error) { diff --git a/http-clients/aws-crt-client/src/main/java/software/amazon/awssdk/http/crt/internal/response/InputStreamAdaptingHttpStreamResponseHandler.java b/http-clients/aws-crt-client/src/main/java/software/amazon/awssdk/http/crt/internal/response/InputStreamAdaptingHttpStreamResponseHandler.java index b1bf4462be89..9e12d4d4679e 100644 --- a/http-clients/aws-crt-client/src/main/java/software/amazon/awssdk/http/crt/internal/response/InputStreamAdaptingHttpStreamResponseHandler.java +++ b/http-clients/aws-crt-client/src/main/java/software/amazon/awssdk/http/crt/internal/response/InputStreamAdaptingHttpStreamResponseHandler.java @@ -28,8 +28,8 @@ import software.amazon.awssdk.crt.http.HttpStream; import software.amazon.awssdk.crt.http.HttpStreamResponseHandler; import software.amazon.awssdk.http.AbortableInputStream; -import software.amazon.awssdk.http.HttpStatusFamily; import software.amazon.awssdk.http.SdkHttpFullResponse; +import software.amazon.awssdk.http.SdkHttpResponse; import software.amazon.awssdk.http.crt.AwsCrtHttpClient; import software.amazon.awssdk.utils.async.InputStreamSubscriber; import software.amazon.awssdk.utils.async.SimplePublisher; @@ -39,17 +39,22 @@ */ @SdkInternalApi public final class InputStreamAdaptingHttpStreamResponseHandler implements HttpStreamResponseHandler { - private final SdkHttpFullResponse.Builder responseBuilder = SdkHttpFullResponse.builder(); + private volatile InputStreamSubscriber inputStreamSubscriber; private final SimplePublisher simplePublisher = new SimplePublisher<>(); private final CompletableFuture requestCompletionFuture; private final HttpClientConnection crtConn; + private final SdkHttpFullResponse.Builder responseBuilder; + private final ResponseHandlerHelper responseHandlerHelper; + public InputStreamAdaptingHttpStreamResponseHandler(HttpClientConnection crtConn, CompletableFuture requestCompletionFuture) { this.crtConn = crtConn; this.requestCompletionFuture = requestCompletionFuture; + this.responseBuilder = SdkHttpResponse.builder(); + this.responseHandlerHelper = new ResponseHandlerHelper(responseBuilder, crtConn); } @Override @@ -59,9 +64,8 @@ public void onResponseHeaders(HttpStream stream, int responseStatusCode, int blo for (HttpHeader h : nextHeaders) { responseBuilder.appendHeader(h.getName(), h.getValue()); } + responseBuilder.statusCode(responseStatusCode); } - - responseBuilder.statusCode(responseStatusCode); } @Override @@ -84,7 +88,7 @@ public int onResponseBody(HttpStream stream, byte[] bodyBytesIn) { writeFuture.whenComplete((result, failure) -> { if (failure != null) { - failFutureAndCloseConnection(stream, failure); + failFutureAndReleaseConnection(stream, failure); return; } @@ -92,7 +96,7 @@ public int onResponseBody(HttpStream stream, byte[] bodyBytesIn) { stream.incrementWindow(bodyBytesIn.length); }); - // the bodyBytesIn have not cleared the queues yet, so do let backpressure do its thing. + // Window will be incremented after the subscriber consumes the data, returning 0 here to disable it. return 0; } @@ -105,11 +109,14 @@ public void onResponseComplete(HttpStream stream, int errorCode) { } } + private void failFutureAndReleaseConnection(HttpStream stream, Throwable failure) { + requestCompletionFuture.completeExceptionally(failure); + responseHandlerHelper.releaseConnection(stream); + } + private void failFutureAndCloseConnection(HttpStream stream, Throwable failure) { requestCompletionFuture.completeExceptionally(failure); - crtConn.shutdown(); - crtConn.close(); - stream.close(); + responseHandlerHelper.closeConnection(stream); } private void onFailedResponseComplete(HttpStream stream, int errorCode) { @@ -121,16 +128,12 @@ private void onFailedResponseComplete(HttpStream stream, int errorCode) { } private void onSuccessfulResponseComplete(HttpStream stream) { - // always close the connection on a 5XX response code. - if (HttpStatusFamily.of(responseBuilder.statusCode()) == HttpStatusFamily.SERVER_ERROR) { - crtConn.shutdown(); - } - // For response without a payload, for example, S3 PutObjectResponse, we need to complete the future // in onResponseComplete callback since onResponseBody will never be invoked. requestCompletionFuture.complete(responseBuilder.build()); + + // requestCompletionFuture has been completed at this point, no need to notify the future simplePublisher.complete(); - crtConn.close(); - stream.close(); + responseHandlerHelper.cleanUpConnectionBasedOnStatusCode(stream); } } diff --git a/http-clients/aws-crt-client/src/main/java/software/amazon/awssdk/http/crt/internal/response/ResponseHandlerHelper.java b/http-clients/aws-crt-client/src/main/java/software/amazon/awssdk/http/crt/internal/response/ResponseHandlerHelper.java new file mode 100644 index 000000000000..6b4a9c91231d --- /dev/null +++ b/http-clients/aws-crt-client/src/main/java/software/amazon/awssdk/http/crt/internal/response/ResponseHandlerHelper.java @@ -0,0 +1,85 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * A copy of the License is located at + * + * http://aws.amazon.com/apache2.0 + * + * or in the "license" file accompanying this file. This file is distributed + * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either + * express or implied. See the License for the specific language governing + * permissions and limitations under the License. + */ + +package software.amazon.awssdk.http.crt.internal.response; + +import java.util.concurrent.atomic.AtomicBoolean; +import software.amazon.awssdk.annotations.SdkInternalApi; +import software.amazon.awssdk.crt.http.HttpClientConnection; +import software.amazon.awssdk.crt.http.HttpHeader; +import software.amazon.awssdk.crt.http.HttpHeaderBlock; +import software.amazon.awssdk.crt.http.HttpStream; +import software.amazon.awssdk.http.HttpStatusFamily; +import software.amazon.awssdk.http.SdkHttpResponse; + +/** + * This is the helper class that contains common logic shared between {@link CrtResponseAdapter} and + * {@link InputStreamAdaptingHttpStreamResponseHandler}. + * + * CRT connection will only be closed, i.e., not reused, in one of the following conditions: + * 1. 5xx server error OR + * 2. It fails to read the response. + */ +@SdkInternalApi +public class ResponseHandlerHelper { + + private final SdkHttpResponse.Builder responseBuilder; + private final HttpClientConnection connection; + private AtomicBoolean connectionClosed = new AtomicBoolean(false); + + public ResponseHandlerHelper(SdkHttpResponse.Builder responseBuilder, HttpClientConnection connection) { + this.responseBuilder = responseBuilder; + this.connection = connection; + } + + public void onResponseHeaders(HttpStream stream, int responseStatusCode, int headerType, HttpHeader[] nextHeaders) { + if (headerType == HttpHeaderBlock.MAIN.getValue()) { + for (HttpHeader h : nextHeaders) { + responseBuilder.appendHeader(h.getName(), h.getValue()); + } + responseBuilder.statusCode(responseStatusCode); + } + } + + /** + * Release the connection back to the pool so that it can be reused. + */ + public void releaseConnection(HttpStream stream) { + if (connectionClosed.compareAndSet(false, true)) { + connection.close(); + stream.close(); + } + } + + /** + * Close the connection completely + */ + public void closeConnection(HttpStream stream) { + if (connectionClosed.compareAndSet(false, true)) { + connection.shutdown(); + connection.close(); + stream.close(); + } + } + + public void cleanUpConnectionBasedOnStatusCode(HttpStream stream) { + // always close the connection on a 5XX response code. + if (HttpStatusFamily.of(responseBuilder.statusCode()) == HttpStatusFamily.SERVER_ERROR) { + closeConnection(stream); + } else { + releaseConnection(stream); + } + } +} diff --git a/http-clients/aws-crt-client/src/test/java/software/amazon/awssdk/http/crt/internal/BaseHttpStreamResponseHandlerTest.java b/http-clients/aws-crt-client/src/test/java/software/amazon/awssdk/http/crt/internal/BaseHttpStreamResponseHandlerTest.java new file mode 100644 index 000000000000..baf126759054 --- /dev/null +++ b/http-clients/aws-crt-client/src/test/java/software/amazon/awssdk/http/crt/internal/BaseHttpStreamResponseHandlerTest.java @@ -0,0 +1,103 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * A copy of the License is located at + * + * http://aws.amazon.com/apache2.0 + * + * or in the "license" file accompanying this file. This file is distributed + * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either + * express or implied. See the License for the specific language governing + * permissions and limitations under the License. + */ + +package software.amazon.awssdk.http.crt.internal; + +import static org.assertj.core.api.Assertions.assertThatThrownBy; +import static org.mockito.Mockito.never; +import static org.mockito.Mockito.verify; + +import java.util.concurrent.CompletableFuture; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.ExtendWith; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.ValueSource; +import org.mockito.Mock; +import org.mockito.junit.jupiter.MockitoExtension; +import software.amazon.awssdk.crt.http.HttpClientConnection; +import software.amazon.awssdk.crt.http.HttpException; +import software.amazon.awssdk.crt.http.HttpHeader; +import software.amazon.awssdk.crt.http.HttpHeaderBlock; +import software.amazon.awssdk.crt.http.HttpStream; +import software.amazon.awssdk.crt.http.HttpStreamResponseHandler; + +@ExtendWith(MockitoExtension.class) +public abstract class BaseHttpStreamResponseHandlerTest { + @Mock HttpClientConnection crtConn; + CompletableFuture requestFuture; + + @Mock + private HttpStream httpStream; + + private HttpStreamResponseHandler responseHandler; + + abstract HttpStreamResponseHandler responseHandler(); + + @BeforeEach + public void setUp() { + requestFuture = new CompletableFuture<>(); + responseHandler = responseHandler(); + } + + @Test + void serverError_shouldShutdownConnection() { + HttpHeader[] httpHeaders = getHttpHeaders(); + responseHandler.onResponseHeaders(httpStream, 500, HttpHeaderBlock.MAIN.getValue(), + httpHeaders); + + responseHandler.onResponseHeadersDone(httpStream, 0); + responseHandler.onResponseComplete(httpStream, 0); + requestFuture.join(); + verify(crtConn).shutdown(); + verify(crtConn).close(); + verify(httpStream).close(); + } + + @ParameterizedTest + @ValueSource(ints = { 200, 400, 202, 403 }) + void nonServerError_shouldNotShutdownConnection(int statusCode) { + HttpHeader[] httpHeaders = getHttpHeaders(); + responseHandler.onResponseHeaders(httpStream, statusCode, HttpHeaderBlock.MAIN.getValue(), + httpHeaders); + + responseHandler.onResponseHeadersDone(httpStream, 0); + responseHandler.onResponseComplete(httpStream, 0); + + requestFuture.join(); + verify(crtConn, never()).shutdown(); + verify(crtConn).close(); + verify(httpStream).close(); + } + + @Test + void failedToGetResponse_shouldShutdownConnection() { + HttpHeader[] httpHeaders = getHttpHeaders(); + responseHandler.onResponseHeaders(httpStream, 200, HttpHeaderBlock.MAIN.getValue(), + httpHeaders); + + responseHandler.onResponseComplete(httpStream, 1); + assertThatThrownBy(() -> requestFuture.join()).hasRootCauseInstanceOf(HttpException.class); + verify(crtConn).shutdown(); + verify(crtConn).close(); + verify(httpStream).close(); + } + + private static HttpHeader[] getHttpHeaders() { + HttpHeader[] httpHeaders = new HttpHeader[1]; + httpHeaders[0] = new HttpHeader("Content-Length", "1"); + return httpHeaders; + } +} diff --git a/http-clients/aws-crt-client/src/test/java/software/amazon/awssdk/http/crt/internal/CrtResponseHandlerTest.java b/http-clients/aws-crt-client/src/test/java/software/amazon/awssdk/http/crt/internal/CrtResponseHandlerTest.java new file mode 100644 index 000000000000..2b18f684b728 --- /dev/null +++ b/http-clients/aws-crt-client/src/test/java/software/amazon/awssdk/http/crt/internal/CrtResponseHandlerTest.java @@ -0,0 +1,40 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * A copy of the License is located at + * + * http://aws.amazon.com/apache2.0 + * + * or in the "license" file accompanying this file. This file is distributed + * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either + * express or implied. See the License for the specific language governing + * permissions and limitations under the License. + */ + +package software.amazon.awssdk.http.crt.internal; + +import java.util.concurrent.CompletableFuture; +import java.util.function.Function; +import org.mockito.Mockito; +import software.amazon.awssdk.core.http.HttpResponseHandler; +import software.amazon.awssdk.core.interceptor.ExecutionAttributes; +import software.amazon.awssdk.core.internal.http.async.AsyncResponseHandler; +import software.amazon.awssdk.crt.http.HttpStreamResponseHandler; +import software.amazon.awssdk.http.SdkHttpFullResponse; +import software.amazon.awssdk.http.async.SdkAsyncHttpResponseHandler; +import software.amazon.awssdk.http.crt.internal.response.CrtResponseAdapter; +import software.amazon.awssdk.http.crt.internal.response.InputStreamAdaptingHttpStreamResponseHandler; + +public class CrtResponseHandlerTest extends BaseHttpStreamResponseHandlerTest { + + @Override + HttpStreamResponseHandler responseHandler() { + AsyncResponseHandler responseHandler = new AsyncResponseHandler<>((response, + executionAttributes) -> null, Function.identity(), new ExecutionAttributes()); + + responseHandler.prepare(); + return CrtResponseAdapter.toCrtResponseHandler(crtConn, requestFuture, responseHandler); + } +} diff --git a/http-clients/aws-crt-client/src/test/java/software/amazon/awssdk/http/crt/internal/InputStreamAdaptingHttpStreamResponseHandlerTest.java b/http-clients/aws-crt-client/src/test/java/software/amazon/awssdk/http/crt/internal/InputStreamAdaptingHttpStreamResponseHandlerTest.java new file mode 100644 index 000000000000..f76e799d2cdb --- /dev/null +++ b/http-clients/aws-crt-client/src/test/java/software/amazon/awssdk/http/crt/internal/InputStreamAdaptingHttpStreamResponseHandlerTest.java @@ -0,0 +1,43 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * A copy of the License is located at + * + * http://aws.amazon.com/apache2.0 + * + * or in the "license" file accompanying this file. This file is distributed + * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either + * express or implied. See the License for the specific language governing + * permissions and limitations under the License. + */ + +package software.amazon.awssdk.http.crt.internal; + +import static org.mockito.Mockito.never; +import static org.mockito.Mockito.verify; + +import java.util.concurrent.CompletableFuture; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.ExtendWith; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.ValueSource; +import org.mockito.Mock; +import org.mockito.junit.jupiter.MockitoExtension; +import software.amazon.awssdk.crt.http.HttpClientConnection; +import software.amazon.awssdk.crt.http.HttpHeader; +import software.amazon.awssdk.crt.http.HttpHeaderBlock; +import software.amazon.awssdk.crt.http.HttpStream; +import software.amazon.awssdk.crt.http.HttpStreamResponseHandler; +import software.amazon.awssdk.http.SdkHttpFullResponse; +import software.amazon.awssdk.http.crt.internal.response.InputStreamAdaptingHttpStreamResponseHandler; + +public class InputStreamAdaptingHttpStreamResponseHandlerTest extends BaseHttpStreamResponseHandlerTest { + + @Override + HttpStreamResponseHandler responseHandler() { + return new InputStreamAdaptingHttpStreamResponseHandler(crtConn, requestFuture); + } +}