Skip to content

Fixed the issue in AWS CRT HTTP clients where the connection is shut down unnecessarily #4825

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Jan 12, 2024
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@
import software.amazon.awssdk.crt.http.HttpHeaderBlock;
import software.amazon.awssdk.crt.http.HttpStream;
import software.amazon.awssdk.crt.http.HttpStreamResponseHandler;
import software.amazon.awssdk.http.HttpStatusFamily;
import software.amazon.awssdk.http.SdkHttpResponse;
import software.amazon.awssdk.http.async.SdkAsyncHttpResponseHandler;
import software.amazon.awssdk.http.crt.AwsCrtAsyncHttpClient;
Expand All @@ -49,14 +48,17 @@ public final class CrtResponseAdapter implements HttpStreamResponseHandler {
private final SdkAsyncHttpResponseHandler responseHandler;
private final SimplePublisher<ByteBuffer> responsePublisher = new SimplePublisher<>();

private final SdkHttpResponse.Builder responseBuilder = SdkHttpResponse.builder();
private final SdkHttpResponse.Builder responseBuilder;
private final ResponseHandlerHelper responseHandlerHelper;

private CrtResponseAdapter(HttpClientConnection connection,
CompletableFuture<Void> completionFuture,
SdkAsyncHttpResponseHandler responseHandler) {
this.connection = Validate.paramNotNull(connection, "connection");
this.completionFuture = Validate.paramNotNull(completionFuture, "completionFuture");
this.responseHandler = Validate.paramNotNull(responseHandler, "responseHandler");
this.responseBuilder = SdkHttpResponse.builder();
this.responseHandlerHelper = new ResponseHandlerHelper(responseBuilder, connection);
}

public static HttpStreamResponseHandler toCrtResponseHandler(HttpClientConnection crtConn,
Expand All @@ -66,18 +68,13 @@ public static HttpStreamResponseHandler toCrtResponseHandler(HttpClientConnectio
}

@Override
public void onResponseHeaders(HttpStream stream, int responseStatusCode, int headerType, HttpHeader[] nextHeaders) {
if (headerType == HttpHeaderBlock.MAIN.getValue()) {
for (HttpHeader h : nextHeaders) {
responseBuilder.appendHeader(h.getName(), h.getValue());
}
}
public void onResponseHeaders(HttpStream stream, int responseStatusCode, int blockType, HttpHeader[] nextHeaders) {
responseHandlerHelper.onResponseHeaders(stream, responseStatusCode, blockType, nextHeaders);
}

@Override
public void onResponseHeadersDone(HttpStream stream, int headerType) {
if (headerType == HttpHeaderBlock.MAIN.getValue()) {
responseBuilder.statusCode(stream.getResponseStatusCode());
responseHandler.onHeaders(responseBuilder.build());
responseHandler.onStream(responsePublisher);
}
Expand All @@ -94,7 +91,7 @@ public int onResponseBody(HttpStream stream, byte[] bodyBytesIn) {

writeFuture.whenComplete((result, failure) -> {
if (failure != null) {
failResponseHandlerAndFuture(stream, failure);
handlePublisherError(stream, failure);
return;
}

Expand All @@ -116,18 +113,18 @@ public void onResponseComplete(HttpStream stream, int errorCode) {
private void onSuccessfulResponseComplete(HttpStream stream) {
responsePublisher.complete().whenComplete((result, failure) -> {
if (failure != null) {
failResponseHandlerAndFuture(stream, failure);
handlePublisherError(stream, failure);
return;
}

if (HttpStatusFamily.of(responseBuilder.statusCode()) == HttpStatusFamily.SERVER_ERROR) {
connection.shutdown();
}

connection.close();
stream.close();
completionFuture.complete(null);
});

responseHandlerHelper.cleanUpConnectionBasedOnStatusCode(stream);
}

private void handlePublisherError(HttpStream stream, Throwable failure) {
failResponseHandlerAndFuture(stream, failure);
responseHandlerHelper.releaseConnection(stream);
}

private void onFailedResponseComplete(HttpStream stream, HttpException error) {
Expand All @@ -136,14 +133,12 @@ private void onFailedResponseComplete(HttpStream stream, HttpException error) {
Throwable toThrow = wrapWithIoExceptionIfRetryable(error);;
responsePublisher.error(toThrow);
failResponseHandlerAndFuture(stream, toThrow);
responseHandlerHelper.closeConnection(stream);
}

private void failResponseHandlerAndFuture(HttpStream stream, Throwable error) {
callResponseHandlerOnError(error);
completionFuture.completeExceptionally(error);
connection.shutdown();
connection.close();
stream.close();
}

private void callResponseHandlerOnError(Throwable error) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,8 @@
import software.amazon.awssdk.crt.http.HttpStream;
import software.amazon.awssdk.crt.http.HttpStreamResponseHandler;
import software.amazon.awssdk.http.AbortableInputStream;
import software.amazon.awssdk.http.HttpStatusFamily;
import software.amazon.awssdk.http.SdkHttpFullResponse;
import software.amazon.awssdk.http.SdkHttpResponse;
import software.amazon.awssdk.http.crt.AwsCrtHttpClient;
import software.amazon.awssdk.utils.async.InputStreamSubscriber;
import software.amazon.awssdk.utils.async.SimplePublisher;
Expand All @@ -39,17 +39,22 @@
*/
@SdkInternalApi
public final class InputStreamAdaptingHttpStreamResponseHandler implements HttpStreamResponseHandler {
private final SdkHttpFullResponse.Builder responseBuilder = SdkHttpFullResponse.builder();

private volatile InputStreamSubscriber inputStreamSubscriber;
private final SimplePublisher<ByteBuffer> simplePublisher = new SimplePublisher<>();

private final CompletableFuture<SdkHttpFullResponse> requestCompletionFuture;
private final HttpClientConnection crtConn;

private final SdkHttpFullResponse.Builder responseBuilder;
private final ResponseHandlerHelper responseHandlerHelper;

public InputStreamAdaptingHttpStreamResponseHandler(HttpClientConnection crtConn,
CompletableFuture<SdkHttpFullResponse> requestCompletionFuture) {
this.crtConn = crtConn;
this.requestCompletionFuture = requestCompletionFuture;
this.responseBuilder = SdkHttpResponse.builder();
this.responseHandlerHelper = new ResponseHandlerHelper(responseBuilder, crtConn);
}

@Override
Expand All @@ -59,9 +64,8 @@ public void onResponseHeaders(HttpStream stream, int responseStatusCode, int blo
for (HttpHeader h : nextHeaders) {
responseBuilder.appendHeader(h.getName(), h.getValue());
}
responseBuilder.statusCode(responseStatusCode);
}

responseBuilder.statusCode(responseStatusCode);
}

@Override
Expand All @@ -84,15 +88,15 @@ public int onResponseBody(HttpStream stream, byte[] bodyBytesIn) {

writeFuture.whenComplete((result, failure) -> {
if (failure != null) {
failFutureAndCloseConnection(stream, failure);
failFutureAndReleaseConnection(stream, failure);
return;
}

// increment the window upon buffer consumption.
stream.incrementWindow(bodyBytesIn.length);
});

// the bodyBytesIn have not cleared the queues yet, so do let backpressure do its thing.
// Window will be incremented after the subscriber consumes the data, returning 0 here to disable it.
return 0;
}

Expand All @@ -105,11 +109,14 @@ public void onResponseComplete(HttpStream stream, int errorCode) {
}
}

private void failFutureAndReleaseConnection(HttpStream stream, Throwable failure) {
requestCompletionFuture.completeExceptionally(failure);
responseHandlerHelper.releaseConnection(stream);
}

private void failFutureAndCloseConnection(HttpStream stream, Throwable failure) {
requestCompletionFuture.completeExceptionally(failure);
crtConn.shutdown();
crtConn.close();
stream.close();
responseHandlerHelper.closeConnection(stream);
}

private void onFailedResponseComplete(HttpStream stream, int errorCode) {
Expand All @@ -121,16 +128,12 @@ private void onFailedResponseComplete(HttpStream stream, int errorCode) {
}

private void onSuccessfulResponseComplete(HttpStream stream) {
// always close the connection on a 5XX response code.
if (HttpStatusFamily.of(responseBuilder.statusCode()) == HttpStatusFamily.SERVER_ERROR) {
crtConn.shutdown();
}

// For response without a payload, for example, S3 PutObjectResponse, we need to complete the future
// in onResponseComplete callback since onResponseBody will never be invoked.
requestCompletionFuture.complete(responseBuilder.build());

// requestCompletionFuture has been completed at this point, no need to notify hte future
simplePublisher.complete();
crtConn.close();
stream.close();
responseHandlerHelper.cleanUpConnectionBasedOnStatusCode(stream);
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
/*
* Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License").
* You may not use this file except in compliance with the License.
* A copy of the License is located at
*
* http://aws.amazon.com/apache2.0
*
* or in the "license" file accompanying this file. This file is distributed
* on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
* express or implied. See the License for the specific language governing
* permissions and limitations under the License.
*/

package software.amazon.awssdk.http.crt.internal.response;

import java.util.concurrent.atomic.AtomicBoolean;
import software.amazon.awssdk.annotations.SdkInternalApi;
import software.amazon.awssdk.crt.http.HttpClientConnection;
import software.amazon.awssdk.crt.http.HttpHeader;
import software.amazon.awssdk.crt.http.HttpHeaderBlock;
import software.amazon.awssdk.crt.http.HttpStream;
import software.amazon.awssdk.http.HttpStatusFamily;
import software.amazon.awssdk.http.SdkHttpResponse;

/**
* This is the helper class that contains common logic shared between {@link CrtResponseAdapter} and
* {@link InputStreamAdaptingHttpStreamResponseHandler}.
*
* CRT connection will only be closed, i.e., not reused, in one of the following conditions:
* 1. 5xx server error OR
* 2. It fails to read the response.
*/
@SdkInternalApi
public class ResponseHandlerHelper {

private final SdkHttpResponse.Builder responseBuilder;
private final HttpClientConnection connection;
private AtomicBoolean connectionClosed = new AtomicBoolean(false);

public ResponseHandlerHelper(SdkHttpResponse.Builder responseBuilder, HttpClientConnection connection) {
this.responseBuilder = responseBuilder;
this.connection = connection;
}

public void onResponseHeaders(HttpStream stream, int responseStatusCode, int headerType, HttpHeader[] nextHeaders) {
if (headerType == HttpHeaderBlock.MAIN.getValue()) {
for (HttpHeader h : nextHeaders) {
responseBuilder.appendHeader(h.getName(), h.getValue());
}
responseBuilder.statusCode(responseStatusCode);
}
}

/**
* Release the connection back to the pool so that it can be reused.
*/
public void releaseConnection(HttpStream stream) {
if (connectionClosed.compareAndSet(false, true)) {
connection.close();
stream.close();
}
}

/**
* Close the connection completely
*/
public void closeConnection(HttpStream stream) {
if (connectionClosed.compareAndSet(false, true)) {
connection.shutdown();
connection.close();
stream.close();
}
}

public void cleanUpConnectionBasedOnStatusCode(HttpStream stream) {
// always close the connection on a 5XX response code.
if (HttpStatusFamily.of(responseBuilder.statusCode()) == HttpStatusFamily.SERVER_ERROR) {
closeConnection(stream);
} else {
releaseConnection(stream);
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,103 @@
/*
* Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License").
* You may not use this file except in compliance with the License.
* A copy of the License is located at
*
* http://aws.amazon.com/apache2.0
*
* or in the "license" file accompanying this file. This file is distributed
* on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
* express or implied. See the License for the specific language governing
* permissions and limitations under the License.
*/

package software.amazon.awssdk.http.crt.internal;

import static org.assertj.core.api.Assertions.assertThatThrownBy;
import static org.mockito.Mockito.never;
import static org.mockito.Mockito.verify;

import java.util.concurrent.CompletableFuture;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.ValueSource;
import org.mockito.Mock;
import org.mockito.junit.jupiter.MockitoExtension;
import software.amazon.awssdk.crt.http.HttpClientConnection;
import software.amazon.awssdk.crt.http.HttpException;
import software.amazon.awssdk.crt.http.HttpHeader;
import software.amazon.awssdk.crt.http.HttpHeaderBlock;
import software.amazon.awssdk.crt.http.HttpStream;
import software.amazon.awssdk.crt.http.HttpStreamResponseHandler;

@ExtendWith(MockitoExtension.class)
public abstract class BaseHttpStreamResponseHandlerTest {
@Mock HttpClientConnection crtConn;
CompletableFuture requestFuture;

@Mock
private HttpStream httpStream;

private HttpStreamResponseHandler responseHandler;

abstract HttpStreamResponseHandler responseHandler();

@BeforeEach
public void setUp() {
requestFuture = new CompletableFuture<>();
responseHandler = responseHandler();
}

@Test
void serverError_shouldShutdownConnection() {
HttpHeader[] httpHeaders = getHttpHeaders();
responseHandler.onResponseHeaders(httpStream, 500, HttpHeaderBlock.MAIN.getValue(),
httpHeaders);

responseHandler.onResponseHeadersDone(httpStream, 0);
responseHandler.onResponseComplete(httpStream, 0);
requestFuture.join();
verify(crtConn).shutdown();
verify(crtConn).close();
verify(httpStream).close();
}

@ParameterizedTest
@ValueSource(ints = { 200, 400, 202, 403 })
void nonServerError_shouldNotShutdownConnection(int statusCode) {
HttpHeader[] httpHeaders = getHttpHeaders();
responseHandler.onResponseHeaders(httpStream, statusCode, HttpHeaderBlock.MAIN.getValue(),
httpHeaders);

responseHandler.onResponseHeadersDone(httpStream, 0);
responseHandler.onResponseComplete(httpStream, 0);

requestFuture.join();
verify(crtConn, never()).shutdown();
verify(crtConn).close();
verify(httpStream).close();
}

@Test
void failedToGetResponse_shouldShutdownConnection() {
HttpHeader[] httpHeaders = getHttpHeaders();
responseHandler.onResponseHeaders(httpStream, 200, HttpHeaderBlock.MAIN.getValue(),
httpHeaders);

responseHandler.onResponseComplete(httpStream, 1);
assertThatThrownBy(() -> requestFuture.join()).hasRootCauseInstanceOf(HttpException.class);
verify(crtConn).shutdown();
verify(crtConn).close();
verify(httpStream).close();
}

private static HttpHeader[] getHttpHeaders() {
HttpHeader[] httpHeaders = new HttpHeader[1];
httpHeaders[0] = new HttpHeader("Content-Length", "1");
return httpHeaders;
}
}
Loading