From 2f01391e0eb8959d2d6c9ca79b178b799256789b Mon Sep 17 00:00:00 2001 From: Richard Smith Date: Tue, 25 Oct 2022 15:15:50 -0700 Subject: [PATCH 1/4] Add support for S3AsyncClient --- pom.xml | 4 +- .../AmazonSQSExtendedAsyncClient.java | 516 ++++++++++++++++++ .../AmazonSQSExtendedAsyncClientBase.java | 244 +++++++++ .../AmazonSQSExtendedClient.java | 183 +------ .../AmazonSQSExtendedClientUtil.java | 168 ++++++ .../ExtendedAsyncClientConfiguration.java | 182 ++++++ .../AmazonSQSExtendedClientTest.java | 47 +- 7 files changed, 1157 insertions(+), 187 deletions(-) create mode 100644 src/main/java/com/amazon/sqs/javamessaging/AmazonSQSExtendedAsyncClient.java create mode 100644 src/main/java/com/amazon/sqs/javamessaging/AmazonSQSExtendedAsyncClientBase.java create mode 100644 src/main/java/com/amazon/sqs/javamessaging/AmazonSQSExtendedClientUtil.java create mode 100644 src/main/java/com/amazon/sqs/javamessaging/ExtendedAsyncClientConfiguration.java diff --git a/pom.xml b/pom.xml index 7c4fb2a..abecb3d 100644 --- a/pom.xml +++ b/pom.xml @@ -6,7 +6,7 @@ com.amazonaws amazon-sqs-java-extended-client-lib - 2.0.2 + 2.1.0 jar Amazon SQS Extended Client Library for Java An extension to the Amazon SQS client that enables sending and receiving messages up to 2GB via Amazon S3. @@ -55,7 +55,7 @@ software.amazon.payloadoffloading payloadoffloading-common - 2.1.2 + 2.2.0 diff --git a/src/main/java/com/amazon/sqs/javamessaging/AmazonSQSExtendedAsyncClient.java b/src/main/java/com/amazon/sqs/javamessaging/AmazonSQSExtendedAsyncClient.java new file mode 100644 index 0000000..d9a5992 --- /dev/null +++ b/src/main/java/com/amazon/sqs/javamessaging/AmazonSQSExtendedAsyncClient.java @@ -0,0 +1,516 @@ +package com.amazon.sqs.javamessaging; + +import static com.amazon.sqs.javamessaging.AmazonSQSExtendedClientUtil.checkMessageAttributes; +import static com.amazon.sqs.javamessaging.AmazonSQSExtendedClientUtil.embedS3PointerInReceiptHandle; +import static com.amazon.sqs.javamessaging.AmazonSQSExtendedClientUtil.getMessagePointerFromModifiedReceiptHandle; +import static com.amazon.sqs.javamessaging.AmazonSQSExtendedClientUtil.getOrigReceiptHandle; +import static com.amazon.sqs.javamessaging.AmazonSQSExtendedClientUtil.getReservedAttributeNameIfPresent; +import static com.amazon.sqs.javamessaging.AmazonSQSExtendedClientUtil.isLarge; +import static com.amazon.sqs.javamessaging.AmazonSQSExtendedClientUtil.isS3ReceiptHandle; +import static com.amazon.sqs.javamessaging.AmazonSQSExtendedClientUtil.updateMessageAttributePayloadSize; + +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.concurrent.CompletableFuture; +import java.util.stream.Collectors; +import org.apache.commons.logging.Log; +import org.apache.commons.logging.LogFactory; +import software.amazon.awssdk.awscore.AwsRequest; +import software.amazon.awssdk.core.exception.SdkClientException; +import software.amazon.awssdk.core.util.VersionInfo; +import software.amazon.awssdk.services.sqs.SqsAsyncClient; +import software.amazon.awssdk.services.sqs.model.ChangeMessageVisibilityBatchRequest; +import software.amazon.awssdk.services.sqs.model.ChangeMessageVisibilityBatchRequestEntry; +import software.amazon.awssdk.services.sqs.model.ChangeMessageVisibilityBatchResponse; +import software.amazon.awssdk.services.sqs.model.ChangeMessageVisibilityRequest; +import software.amazon.awssdk.services.sqs.model.ChangeMessageVisibilityResponse; +import software.amazon.awssdk.services.sqs.model.DeleteMessageBatchRequest; +import software.amazon.awssdk.services.sqs.model.DeleteMessageBatchRequestEntry; +import software.amazon.awssdk.services.sqs.model.DeleteMessageBatchResponse; +import software.amazon.awssdk.services.sqs.model.DeleteMessageRequest; +import software.amazon.awssdk.services.sqs.model.DeleteMessageResponse; +import software.amazon.awssdk.services.sqs.model.Message; +import software.amazon.awssdk.services.sqs.model.MessageAttributeValue; +import software.amazon.awssdk.services.sqs.model.PurgeQueueRequest; +import software.amazon.awssdk.services.sqs.model.PurgeQueueResponse; +import software.amazon.awssdk.services.sqs.model.ReceiveMessageRequest; +import software.amazon.awssdk.services.sqs.model.ReceiveMessageResponse; +import software.amazon.awssdk.services.sqs.model.SendMessageBatchRequest; +import software.amazon.awssdk.services.sqs.model.SendMessageBatchRequestEntry; +import software.amazon.awssdk.services.sqs.model.SendMessageBatchResponse; +import software.amazon.awssdk.services.sqs.model.SendMessageRequest; +import software.amazon.awssdk.services.sqs.model.SendMessageResponse; +import software.amazon.awssdk.utils.StringUtils; +import software.amazon.payloadoffloading.PayloadStoreAsync; +import software.amazon.payloadoffloading.S3AsyncDao; +import software.amazon.payloadoffloading.S3BackedPayloadStoreAsync; +import software.amazon.payloadoffloading.Util; + +/** + * Amazon SQS Extended Async Client extends the functionality of Amazon Async SQS + * client. All service calls made using this client are asynchronous, and will return + * immediately with a {@link CompletableFuture} that completes when the operation + * completes or when an exception is thrown. + * + *

+ * The Amazon SQS extended client enables sending and receiving large messages + * via Amazon S3. You can use this library to: + *

+ * + * + */ +public class AmazonSQSExtendedAsyncClient extends AmazonSQSExtendedAsyncClientBase implements SqsAsyncClient { + static final String USER_AGENT_NAME = AmazonSQSExtendedAsyncClient.class.getSimpleName(); + static final String USER_AGENT_VERSION = VersionInfo.SDK_VERSION; + + private static final Log LOG = LogFactory.getLog(AmazonSQSExtendedAsyncClient.class); + private ExtendedAsyncClientConfiguration clientConfiguration; + private PayloadStoreAsync payloadStore; + + /** + * Constructs a new Amazon SQS extended async client to invoke service methods on + * Amazon SQS with extended functionality using the specified Amazon SQS + * client object. + * + *

+ * All service calls made using this client are asynchronous, and will return + * immediately with a {@link CompletableFuture} that completes when the operation + * completes or when an exception is thrown. + * + * @param sqsClient + * The Amazon SQS async client to use to connect to Amazon SQS. + */ + public AmazonSQSExtendedAsyncClient(SqsAsyncClient sqsClient) { + this(sqsClient, new ExtendedAsyncClientConfiguration()); + } + + /** + * Constructs a new Amazon SQS extended client to invoke service methods on + * Amazon SQS with extended functionality using the specified Amazon SQS + * client object. + * + *

+ * All service calls made using this client are asynchronous, and will return + * immediately with a {@link CompletableFuture} that completes when the operation + * completes or when an exception is thrown. + * + * @param sqsClient + * The Amazon SQS async client to use to connect to Amazon SQS. + * @param extendedClientConfig + * The extended client configuration options controlling the + * functionality of this client. + */ + public AmazonSQSExtendedAsyncClient(SqsAsyncClient sqsClient, + ExtendedAsyncClientConfiguration extendedClientConfig) { + super(sqsClient); + this.clientConfiguration = new ExtendedAsyncClientConfiguration(extendedClientConfig); + S3AsyncDao s3Dao = new S3AsyncDao(clientConfiguration.getS3AsyncClient(), + clientConfiguration.getServerSideEncryptionStrategy(), + clientConfiguration.getObjectCannedACL()); + this.payloadStore = new S3BackedPayloadStoreAsync(s3Dao, clientConfiguration.getS3BucketName()); + } + + /** + * {@inheritDoc} + */ + @Override + public CompletableFuture sendMessage(SendMessageRequest sendMessageRequest) { + // TODO: Clone request since it's modified in this method and will cause issues if the client reuses request + // object. + if (sendMessageRequest == null) { + String errorMessage = "sendMessageRequest cannot be null."; + LOG.error(errorMessage); + CompletableFuture futureEx = new CompletableFuture<>(); + futureEx.completeExceptionally(SdkClientException.create(errorMessage)); + return futureEx; + } + + SendMessageRequest.Builder sendMessageRequestBuilder = sendMessageRequest.toBuilder(); + sendMessageRequest = appendUserAgent(sendMessageRequestBuilder).build(); + + if (!clientConfiguration.isPayloadSupportEnabled()) { + return super.sendMessage(sendMessageRequest); + } + + if (StringUtils.isEmpty(sendMessageRequest.messageBody())) { + String errorMessage = "messageBody cannot be null or empty."; + LOG.error(errorMessage); + CompletableFuture futureEx = new CompletableFuture<>(); + futureEx.completeExceptionally(SdkClientException.create(errorMessage)); + return futureEx; + } + + //Check message attributes for ExtendedClient related constraints + try { + checkMessageAttributes(clientConfiguration.getPayloadSizeThreshold(), sendMessageRequest.messageAttributes()); + } catch (SdkClientException e) { + CompletableFuture futureEx = new CompletableFuture<>(); + futureEx.completeExceptionally(e); + return futureEx; + } + + if (clientConfiguration.isAlwaysThroughS3() + || isLarge(clientConfiguration.getPayloadSizeThreshold(), sendMessageRequest)) { + return storeMessageInS3(sendMessageRequest) + .thenCompose(modifiedRequest -> super.sendMessage(modifiedRequest)); + } + + return super.sendMessage(sendMessageRequest); + } + + /** + * {@inheritDoc} + */ + @Override + public CompletableFuture receiveMessage(ReceiveMessageRequest receiveMessageRequest) { + // TODO: Clone request since it's modified in this method and will cause issues if the client reuses request + // object. + if (receiveMessageRequest == null) { + String errorMessage = "receiveMessageRequest cannot be null."; + LOG.error(errorMessage); + + CompletableFuture future = new CompletableFuture<>(); + future.completeExceptionally(SdkClientException.create(errorMessage)); + return future; + } + + ReceiveMessageRequest.Builder receiveMessageRequestBuilder = receiveMessageRequest.toBuilder(); + appendUserAgent(receiveMessageRequestBuilder); + + if (!clientConfiguration.isPayloadSupportEnabled()) { + return super.receiveMessage(receiveMessageRequestBuilder.build()); + } + + // Remove before adding to avoid any duplicates + List messageAttributeNames = new ArrayList<>(receiveMessageRequest.messageAttributeNames()); + messageAttributeNames.removeAll(AmazonSQSExtendedClientUtil.RESERVED_ATTRIBUTE_NAMES); + messageAttributeNames.addAll(AmazonSQSExtendedClientUtil.RESERVED_ATTRIBUTE_NAMES); + receiveMessageRequestBuilder.messageAttributeNames(messageAttributeNames); + receiveMessageRequest = receiveMessageRequestBuilder.build(); + + return super.receiveMessage(receiveMessageRequest) + .thenCompose(receiveMessageResponse -> { + ReceiveMessageResponse.Builder receiveMessageResponseBuilder = receiveMessageResponse.toBuilder(); + + List messages = receiveMessageResponse.messages(); + List> modifiedMessageFutures = new ArrayList<>(messages.size()); + for (Message message : messages) { + Message.Builder messageBuilder = message.toBuilder(); + + // For each received message check if they are stored in S3. + Optional largePayloadAttributeName = getReservedAttributeNameIfPresent( + message.messageAttributes()); + if (largePayloadAttributeName.isPresent()) { + // Not S3 + modifiedMessageFutures.add(CompletableFuture.completedFuture(messageBuilder.build())); + } else { + // In S3 + final String largeMessagePointer = message.body() + .replace("com.amazon.sqs.javamessaging.MessageS3Pointer", + "software.amazon.payloadoffloading.PayloadS3Pointer"); + + // Retrieve original payload + modifiedMessageFutures.add(payloadStore.getOriginalPayload(largeMessagePointer) + .thenApply(originalPayload -> { + // Set original payload + messageBuilder.body(originalPayload); + + // Remove the additional attribute before returning the message + // to user. + Map messageAttributes = new HashMap<>( + message.messageAttributes()); + messageAttributes.keySet().removeAll(AmazonSQSExtendedClientUtil.RESERVED_ATTRIBUTE_NAMES); + messageBuilder.messageAttributes(messageAttributes); + + // Embed s3 object pointer in the receipt handle. + String modifiedReceiptHandle = embedS3PointerInReceiptHandle( + message.receiptHandle(), + largeMessagePointer); + messageBuilder.receiptHandle(modifiedReceiptHandle); + + return messageBuilder.build(); + })); + } + } + + // Convert list of message futures to a future list of messages. + return CompletableFuture.allOf( + modifiedMessageFutures.toArray(new CompletableFuture[modifiedMessageFutures.size()])) + .thenApply(v -> modifiedMessageFutures.stream() + .map(CompletableFuture::join) + .collect(Collectors.toList())); + }) + .thenApply(modifiedMessages -> { + // Build response with modified message list. + ReceiveMessageResponse.Builder receiveMessageResponseBuilder = ReceiveMessageResponse.builder(); + receiveMessageResponseBuilder.messages(modifiedMessages); + return receiveMessageResponseBuilder.build(); + }); + } + + /** + * {@inheritDoc} + */ + @Override + public CompletableFuture deleteMessage(DeleteMessageRequest deleteMessageRequest) { + if (deleteMessageRequest == null) { + String errorMessage = "deleteMessageRequest cannot be null."; + LOG.error(errorMessage); + + CompletableFuture future = new CompletableFuture<>(); + future.completeExceptionally(SdkClientException.create(errorMessage)); + return future; + } + + DeleteMessageRequest.Builder deleteMessageRequestBuilder = deleteMessageRequest.toBuilder(); + appendUserAgent(deleteMessageRequestBuilder); + + String receiptHandle = deleteMessageRequest.receiptHandle(); + String origReceiptHandle = receiptHandle; + String messagePointer = null; + + // Update original receipt handle if needed. + if (clientConfiguration.isPayloadSupportEnabled() && isS3ReceiptHandle(receiptHandle)) { + origReceiptHandle = getOrigReceiptHandle(receiptHandle); + + // Delete pay load from S3 if needed + if (clientConfiguration.doesCleanupS3Payload()) { + messagePointer = getMessagePointerFromModifiedReceiptHandle(receiptHandle); + } + } + + // The actual message to delete from SQS. + deleteMessageRequestBuilder.receiptHandle(origReceiptHandle); + + // Check if message is in S3 or only in SQS. + if (messagePointer == null) { + // Delete only from SQS + return super.deleteMessage(deleteMessageRequestBuilder.build()); + } + + // Delete from SQS first, then S3. + final String messageToDeletePointer = messagePointer; + return super.deleteMessage(deleteMessageRequestBuilder.build()) + .thenCompose(deleteMessageResponse -> + payloadStore.deleteOriginalPayload(messageToDeletePointer) + .thenApply(v -> deleteMessageResponse)); + } + + /** + * {@inheritDoc} + */ + @Override + public CompletableFuture changeMessageVisibility( + ChangeMessageVisibilityRequest changeMessageVisibilityRequest) { + + ChangeMessageVisibilityRequest.Builder changeMessageVisibilityRequestBuilder = + changeMessageVisibilityRequest.toBuilder(); + if (isS3ReceiptHandle(changeMessageVisibilityRequest.receiptHandle())) { + changeMessageVisibilityRequestBuilder.receiptHandle( + getOrigReceiptHandle(changeMessageVisibilityRequest.receiptHandle())); + } + return amazonSqsToBeExtended.changeMessageVisibility(changeMessageVisibilityRequestBuilder.build()); + } + + /** + * {@inheritDoc} + */ + @Override + public CompletableFuture sendMessageBatch( + SendMessageBatchRequest sendMessageBatchRequestIn) { + + if (sendMessageBatchRequestIn == null) { + String errorMessage = "sendMessageBatchRequest cannot be null."; + LOG.error(errorMessage); + CompletableFuture futureEx = new CompletableFuture<>(); + futureEx.completeExceptionally(SdkClientException.create(errorMessage)); + return futureEx; + } + + SendMessageBatchRequest.Builder sendMessageBatchRequestBuilder = sendMessageBatchRequestIn.toBuilder(); + appendUserAgent(sendMessageBatchRequestBuilder); + SendMessageBatchRequest sendMessageBatchRequest = sendMessageBatchRequestBuilder.build(); + + if (!clientConfiguration.isPayloadSupportEnabled()) { + return super.sendMessageBatch(sendMessageBatchRequest); + } + + List> batchEntryFutures = new ArrayList<>( + sendMessageBatchRequest.entries().size()); + boolean hasS3Entries = false; + for (SendMessageBatchRequestEntry entry : sendMessageBatchRequest.entries()) { + //Check message attributes for ExtendedClient related constraints + try { + checkMessageAttributes(clientConfiguration.getPayloadSizeThreshold(), entry.messageAttributes()); + } catch (SdkClientException e) { + CompletableFuture futureEx = new CompletableFuture<>(); + futureEx.completeExceptionally(e); + return futureEx; + } + + if (clientConfiguration.isAlwaysThroughS3() + || isLarge(clientConfiguration.getPayloadSizeThreshold(), entry)) { + batchEntryFutures.add(storeMessageInS3(entry)); + } else { + batchEntryFutures.add(CompletableFuture.completedFuture(entry)); + } + } + + if (!hasS3Entries) { + return super.sendMessageBatch(sendMessageBatchRequest); + } + + // Convert list of entry futures to a future list of entries. + return CompletableFuture.allOf( + batchEntryFutures.toArray(new CompletableFuture[batchEntryFutures.size()])) + .thenApply(v -> batchEntryFutures.stream() + .map(CompletableFuture::join) + .collect(Collectors.toList())) + .thenCompose(batchEntries -> { + SendMessageBatchRequest modifiedBatchRequest = + sendMessageBatchRequest.toBuilder().entries(batchEntries).build(); + return super.sendMessageBatch(modifiedBatchRequest); + }); + } + + /** + * {@inheritDoc} + */ + @Override + public CompletableFuture deleteMessageBatch( + DeleteMessageBatchRequest deleteMessageBatchRequest) { + + if (deleteMessageBatchRequest == null) { + String errorMessage = "deleteMessageBatchRequest cannot be null."; + LOG.error(errorMessage); + CompletableFuture futureEx = new CompletableFuture<>(); + futureEx.completeExceptionally(SdkClientException.create(errorMessage)); + return futureEx; + } + + DeleteMessageBatchRequest.Builder deleteMessageBatchRequestBuilder = deleteMessageBatchRequest.toBuilder(); + appendUserAgent(deleteMessageBatchRequestBuilder); + + if (!clientConfiguration.isPayloadSupportEnabled()) { + return super.deleteMessageBatch(deleteMessageBatchRequest); + } + + List entries = new ArrayList<>(deleteMessageBatchRequest.entries().size()); + for (DeleteMessageBatchRequestEntry entry : deleteMessageBatchRequest.entries()) { + DeleteMessageBatchRequestEntry.Builder entryBuilder = entry.toBuilder(); + String receiptHandle = entry.receiptHandle(); + String origReceiptHandle = receiptHandle; + + // Update original receipt handle if needed + if (isS3ReceiptHandle(receiptHandle)) { + origReceiptHandle = getOrigReceiptHandle(receiptHandle); + // Delete s3 payload if needed + if (clientConfiguration.doesCleanupS3Payload()) { + String messagePointer = getMessagePointerFromModifiedReceiptHandle(receiptHandle); + payloadStore.deleteOriginalPayload(messagePointer); + } + } + + entryBuilder.receiptHandle(origReceiptHandle); + entries.add(entryBuilder.build()); + } + + deleteMessageBatchRequestBuilder.entries(entries); + return super.deleteMessageBatch(deleteMessageBatchRequestBuilder.build()); + } + + /** + * {@inheritDoc} + */ + @Override + public CompletableFuture changeMessageVisibilityBatch( + ChangeMessageVisibilityBatchRequest changeMessageVisibilityBatchRequest) { + List entries = new ArrayList<>( + changeMessageVisibilityBatchRequest.entries().size()); + for (ChangeMessageVisibilityBatchRequestEntry entry : changeMessageVisibilityBatchRequest.entries()) { + ChangeMessageVisibilityBatchRequestEntry.Builder entryBuilder = entry.toBuilder(); + if (isS3ReceiptHandle(entry.receiptHandle())) { + entryBuilder.receiptHandle(getOrigReceiptHandle(entry.receiptHandle())); + } + entries.add(entryBuilder.build()); + } + + return amazonSqsToBeExtended.changeMessageVisibilityBatch( + changeMessageVisibilityBatchRequest.toBuilder().entries(entries).build()); + } + + /** + * {@inheritDoc} + */ + @Override + public CompletableFuture purgeQueue(PurgeQueueRequest purgeQueueRequest) { + LOG.warn("Calling purgeQueue deletes SQS messages without deleting their payload from S3."); + + if (purgeQueueRequest == null) { + String errorMessage = "purgeQueueRequest cannot be null."; + LOG.error(errorMessage); + CompletableFuture futureEx = new CompletableFuture<>(); + futureEx.completeExceptionally(SdkClientException.create(errorMessage)); + return futureEx; + } + + PurgeQueueRequest.Builder purgeQueueRequestBuilder = purgeQueueRequest.toBuilder(); + appendUserAgent(purgeQueueRequestBuilder); + + return super.purgeQueue(purgeQueueRequestBuilder.build()); + } + + private CompletableFuture storeMessageInS3(SendMessageBatchRequestEntry batchEntry) { + // Read the content of the message from message body + String messageContentStr = batchEntry.messageBody(); + + Long messageContentSize = Util.getStringSizeInBytes(messageContentStr); + + SendMessageBatchRequestEntry.Builder batchEntryBuilder = batchEntry.toBuilder(); + + batchEntryBuilder.messageAttributes( + updateMessageAttributePayloadSize(batchEntry.messageAttributes(), messageContentSize, + clientConfiguration.usesLegacyReservedAttributeName())); + + // Store the message content in S3. + return payloadStore.storeOriginalPayload(messageContentStr) + .thenApply(largeMessagePointer -> { + batchEntryBuilder.messageBody(largeMessagePointer); + return batchEntryBuilder.build(); + }); + } + + private CompletableFuture storeMessageInS3(SendMessageRequest sendMessageRequest) { + // Read the content of the message from message body + String messageContentStr = sendMessageRequest.messageBody(); + + Long messageContentSize = Util.getStringSizeInBytes(messageContentStr); + + SendMessageRequest.Builder sendMessageRequestBuilder = sendMessageRequest.toBuilder(); + + sendMessageRequestBuilder.messageAttributes( + updateMessageAttributePayloadSize(sendMessageRequest.messageAttributes(), messageContentSize, + clientConfiguration.usesLegacyReservedAttributeName())); + + // Store the message content in S3. + return payloadStore.storeOriginalPayload(messageContentStr) + .thenApply(largeMessagePointer -> { + sendMessageRequestBuilder.messageBody(largeMessagePointer); + return sendMessageRequestBuilder.build(); + }); + } + + private static T appendUserAgent(final T builder) { + return AmazonSQSExtendedClientUtil.appendUserAgent(builder, USER_AGENT_NAME, USER_AGENT_VERSION); + } +} diff --git a/src/main/java/com/amazon/sqs/javamessaging/AmazonSQSExtendedAsyncClientBase.java b/src/main/java/com/amazon/sqs/javamessaging/AmazonSQSExtendedAsyncClientBase.java new file mode 100644 index 0000000..2547297 --- /dev/null +++ b/src/main/java/com/amazon/sqs/javamessaging/AmazonSQSExtendedAsyncClientBase.java @@ -0,0 +1,244 @@ +package com.amazon.sqs.javamessaging; + +import java.util.concurrent.CompletableFuture; +import software.amazon.awssdk.services.sqs.SqsAsyncClient; +import software.amazon.awssdk.services.sqs.model.AddPermissionRequest; +import software.amazon.awssdk.services.sqs.model.AddPermissionResponse; +import software.amazon.awssdk.services.sqs.model.ChangeMessageVisibilityBatchRequest; +import software.amazon.awssdk.services.sqs.model.ChangeMessageVisibilityBatchResponse; +import software.amazon.awssdk.services.sqs.model.ChangeMessageVisibilityRequest; +import software.amazon.awssdk.services.sqs.model.ChangeMessageVisibilityResponse; +import software.amazon.awssdk.services.sqs.model.CreateQueueRequest; +import software.amazon.awssdk.services.sqs.model.CreateQueueResponse; +import software.amazon.awssdk.services.sqs.model.DeleteMessageBatchRequest; +import software.amazon.awssdk.services.sqs.model.DeleteMessageBatchResponse; +import software.amazon.awssdk.services.sqs.model.DeleteMessageRequest; +import software.amazon.awssdk.services.sqs.model.DeleteMessageResponse; +import software.amazon.awssdk.services.sqs.model.DeleteQueueRequest; +import software.amazon.awssdk.services.sqs.model.DeleteQueueResponse; +import software.amazon.awssdk.services.sqs.model.GetQueueAttributesRequest; +import software.amazon.awssdk.services.sqs.model.GetQueueAttributesResponse; +import software.amazon.awssdk.services.sqs.model.GetQueueUrlRequest; +import software.amazon.awssdk.services.sqs.model.GetQueueUrlResponse; +import software.amazon.awssdk.services.sqs.model.ListDeadLetterSourceQueuesRequest; +import software.amazon.awssdk.services.sqs.model.ListDeadLetterSourceQueuesResponse; +import software.amazon.awssdk.services.sqs.model.ListQueueTagsRequest; +import software.amazon.awssdk.services.sqs.model.ListQueueTagsResponse; +import software.amazon.awssdk.services.sqs.model.ListQueuesRequest; +import software.amazon.awssdk.services.sqs.model.ListQueuesResponse; +import software.amazon.awssdk.services.sqs.model.PurgeQueueRequest; +import software.amazon.awssdk.services.sqs.model.PurgeQueueResponse; +import software.amazon.awssdk.services.sqs.model.ReceiveMessageRequest; +import software.amazon.awssdk.services.sqs.model.ReceiveMessageResponse; +import software.amazon.awssdk.services.sqs.model.RemovePermissionRequest; +import software.amazon.awssdk.services.sqs.model.RemovePermissionResponse; +import software.amazon.awssdk.services.sqs.model.SendMessageBatchRequest; +import software.amazon.awssdk.services.sqs.model.SendMessageBatchResponse; +import software.amazon.awssdk.services.sqs.model.SendMessageRequest; +import software.amazon.awssdk.services.sqs.model.SendMessageResponse; +import software.amazon.awssdk.services.sqs.model.SetQueueAttributesRequest; +import software.amazon.awssdk.services.sqs.model.SetQueueAttributesResponse; +import software.amazon.awssdk.services.sqs.model.TagQueueRequest; +import software.amazon.awssdk.services.sqs.model.TagQueueResponse; +import software.amazon.awssdk.services.sqs.model.UntagQueueRequest; +import software.amazon.awssdk.services.sqs.model.UntagQueueResponse; + +abstract class AmazonSQSExtendedAsyncClientBase implements SqsAsyncClient { + SqsAsyncClient amazonSqsToBeExtended; + + public AmazonSQSExtendedAsyncClientBase(SqsAsyncClient sqsClient) { + amazonSqsToBeExtended = sqsClient; + } + + /** + * {@inheritDoc} + */ + @Override + public CompletableFuture sendMessage(SendMessageRequest sendMessageRequest) { + return amazonSqsToBeExtended.sendMessage(sendMessageRequest); + } + + /** + * {@inheritDoc} + */ + @Override + public CompletableFuture receiveMessage(ReceiveMessageRequest receiveMessageRequest) { + return amazonSqsToBeExtended.receiveMessage(receiveMessageRequest); + } + + /** + * {@inheritDoc} + */ + @Override + public CompletableFuture deleteMessage(DeleteMessageRequest deleteMessageRequest) { + return amazonSqsToBeExtended.deleteMessage(deleteMessageRequest); + } + + /** + * {@inheritDoc} + */ + @Override + public CompletableFuture setQueueAttributes( + SetQueueAttributesRequest setQueueAttributesRequest) { + return amazonSqsToBeExtended.setQueueAttributes(setQueueAttributesRequest); + } + + /** + * {@inheritDoc} + */ + @Override + public CompletableFuture changeMessageVisibilityBatch( + ChangeMessageVisibilityBatchRequest changeMessageVisibilityBatchRequest) { + return amazonSqsToBeExtended.changeMessageVisibilityBatch(changeMessageVisibilityBatchRequest); + } + + /** + * {@inheritDoc} + */ + @Override + public CompletableFuture changeMessageVisibility( + ChangeMessageVisibilityRequest changeMessageVisibilityRequest) { + return amazonSqsToBeExtended.changeMessageVisibility(changeMessageVisibilityRequest); + } + + /** + * {@inheritDoc} + */ + @Override + public CompletableFuture getQueueUrl(GetQueueUrlRequest getQueueUrlRequest) { + return amazonSqsToBeExtended.getQueueUrl(getQueueUrlRequest); + } + + /** + * {@inheritDoc} + */ + @Override + public CompletableFuture removePermission( + RemovePermissionRequest removePermissionRequest) { + return amazonSqsToBeExtended.removePermission(removePermissionRequest); + } + + /** + * {@inheritDoc} + */ + @Override + public CompletableFuture getQueueAttributes( + GetQueueAttributesRequest getQueueAttributesRequest) { + return amazonSqsToBeExtended.getQueueAttributes(getQueueAttributesRequest); + } + + /** + * {@inheritDoc} + */ + @Override + public CompletableFuture sendMessageBatch( + SendMessageBatchRequest sendMessageBatchRequest) { + return amazonSqsToBeExtended.sendMessageBatch(sendMessageBatchRequest); + } + + /** + * {@inheritDoc} + */ + @Override + public CompletableFuture purgeQueue(PurgeQueueRequest purgeQueueRequest) { + return amazonSqsToBeExtended.purgeQueue(purgeQueueRequest); + } + + /** + * {@inheritDoc} + */ + @Override + public CompletableFuture listDeadLetterSourceQueues( + ListDeadLetterSourceQueuesRequest listDeadLetterSourceQueuesRequest) { + return amazonSqsToBeExtended.listDeadLetterSourceQueues(listDeadLetterSourceQueuesRequest); + } + + /** + * {@inheritDoc} + */ + @Override + public CompletableFuture deleteQueue(DeleteQueueRequest deleteQueueRequest) { + return amazonSqsToBeExtended.deleteQueue(deleteQueueRequest); + } + + /** + * {@inheritDoc} + */ + @Override + public CompletableFuture listQueues(ListQueuesRequest listQueuesRequest) { + return amazonSqsToBeExtended.listQueues(listQueuesRequest); + } + + /** + * {@inheritDoc} + */ + @Override + public CompletableFuture listQueues() { + return amazonSqsToBeExtended.listQueues(); + } + + /** + * {@inheritDoc} + */ + @Override + public CompletableFuture deleteMessageBatch( + DeleteMessageBatchRequest deleteMessageBatchRequest) { + return amazonSqsToBeExtended.deleteMessageBatch(deleteMessageBatchRequest); + } + + /** + * {@inheritDoc} + */ + @Override + public CompletableFuture createQueue(CreateQueueRequest createQueueRequest) { + return amazonSqsToBeExtended.createQueue(createQueueRequest); + } + + /** + * {@inheritDoc} + */ + @Override + public CompletableFuture addPermission(AddPermissionRequest addPermissionRequest) { + return amazonSqsToBeExtended.addPermission(addPermissionRequest); + } + + /** + * {@inheritDoc} + */ + @Override + public CompletableFuture listQueueTags(final ListQueueTagsRequest listQueueTagsRequest) { + return amazonSqsToBeExtended.listQueueTags(listQueueTagsRequest); + } + + /** + * {@inheritDoc} + */ + @Override + public CompletableFuture tagQueue(final TagQueueRequest tagQueueRequest) { + return amazonSqsToBeExtended.tagQueue(tagQueueRequest); + } + + /** + * {@inheritDoc} + */ + @Override + public CompletableFuture untagQueue(final UntagQueueRequest untagQueueRequest) { + return amazonSqsToBeExtended.untagQueue(untagQueueRequest); + } + + /** + * {@inheritDoc} + */ + @Override + public String serviceName() { + return amazonSqsToBeExtended.serviceName(); + } + + /** + * {@inheritDoc} + */ + @Override + public void close() { + amazonSqsToBeExtended.close(); + } +} diff --git a/src/main/java/com/amazon/sqs/javamessaging/AmazonSQSExtendedClient.java b/src/main/java/com/amazon/sqs/javamessaging/AmazonSQSExtendedClient.java index 7520706..65f8a79 100644 --- a/src/main/java/com/amazon/sqs/javamessaging/AmazonSQSExtendedClient.java +++ b/src/main/java/com/amazon/sqs/javamessaging/AmazonSQSExtendedClient.java @@ -15,26 +15,28 @@ package com.amazon.sqs.javamessaging; -import java.lang.UnsupportedOperationException; +import static com.amazon.sqs.javamessaging.AmazonSQSExtendedClientUtil.checkMessageAttributes; +import static com.amazon.sqs.javamessaging.AmazonSQSExtendedClientUtil.embedS3PointerInReceiptHandle; +import static com.amazon.sqs.javamessaging.AmazonSQSExtendedClientUtil.getMessagePointerFromModifiedReceiptHandle; +import static com.amazon.sqs.javamessaging.AmazonSQSExtendedClientUtil.getOrigReceiptHandle; +import static com.amazon.sqs.javamessaging.AmazonSQSExtendedClientUtil.getReservedAttributeNameIfPresent; +import static com.amazon.sqs.javamessaging.AmazonSQSExtendedClientUtil.isLarge; +import static com.amazon.sqs.javamessaging.AmazonSQSExtendedClientUtil.isS3ReceiptHandle; +import static com.amazon.sqs.javamessaging.AmazonSQSExtendedClientUtil.updateMessageAttributePayloadSize; + import java.util.ArrayList; -import java.util.Arrays; import java.util.HashMap; import java.util.List; import java.util.Map; import java.util.Optional; - import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; import software.amazon.awssdk.awscore.AwsRequest; -import software.amazon.awssdk.awscore.AwsRequestOverrideConfiguration; import software.amazon.awssdk.awscore.exception.AwsServiceException; -import software.amazon.awssdk.core.ApiName; -import software.amazon.awssdk.core.SdkBytes; import software.amazon.awssdk.core.exception.SdkClientException; import software.amazon.awssdk.core.exception.SdkException; import software.amazon.awssdk.core.util.VersionInfo; import software.amazon.awssdk.services.sqs.SqsClient; - import software.amazon.awssdk.services.sqs.model.BatchEntryIdsNotDistinctException; import software.amazon.awssdk.services.sqs.model.BatchRequestTooLongException; import software.amazon.awssdk.services.sqs.model.ChangeMessageVisibilityBatchRequest; @@ -70,7 +72,6 @@ import software.amazon.awssdk.services.sqs.model.SqsException; import software.amazon.awssdk.services.sqs.model.TooManyEntriesInBatchRequestException; import software.amazon.awssdk.utils.StringUtils; -import software.amazon.payloadoffloading.PayloadS3Pointer; import software.amazon.payloadoffloading.PayloadStore; import software.amazon.payloadoffloading.S3BackedPayloadStore; import software.amazon.payloadoffloading.S3Dao; @@ -101,9 +102,6 @@ public class AmazonSQSExtendedClient extends AmazonSQSExtendedClientBase impleme static final String USER_AGENT_VERSION = VersionInfo.SDK_VERSION; private static final Log LOG = LogFactory.getLog(AmazonSQSExtendedClient.class); - static final String LEGACY_RESERVED_ATTRIBUTE_NAME = "SQSLargePayloadSize"; - static final List RESERVED_ATTRIBUTE_NAMES = Arrays.asList(LEGACY_RESERVED_ATTRIBUTE_NAME, - SQSExtendedClientConstants.RESERVED_ATTRIBUTE_NAME); private ExtendedClientConfiguration clientConfiguration; private PayloadStore payloadStore; @@ -204,9 +202,10 @@ public SendMessageResponse sendMessage(SendMessageRequest sendMessageRequest) { } //Check message attributes for ExtendedClient related constraints - checkMessageAttributes(sendMessageRequest.messageAttributes()); + checkMessageAttributes(clientConfiguration.getPayloadSizeThreshold(), sendMessageRequest.messageAttributes()); - if (clientConfiguration.isAlwaysThroughS3() || isLarge(sendMessageRequest)) { + if (clientConfiguration.isAlwaysThroughS3() + || isLarge(clientConfiguration.getPayloadSizeThreshold(), sendMessageRequest)) { sendMessageRequest = storeMessageInS3(sendMessageRequest); } return super.sendMessage(sendMessageRequest); @@ -320,8 +319,8 @@ public ReceiveMessageResponse receiveMessage(ReceiveMessageRequest receiveMessag } //Remove before adding to avoid any duplicates List messageAttributeNames = new ArrayList<>(receiveMessageRequest.messageAttributeNames()); - messageAttributeNames.removeAll(RESERVED_ATTRIBUTE_NAMES); - messageAttributeNames.addAll(RESERVED_ATTRIBUTE_NAMES); + messageAttributeNames.removeAll(AmazonSQSExtendedClientUtil.RESERVED_ATTRIBUTE_NAMES); + messageAttributeNames.addAll(AmazonSQSExtendedClientUtil.RESERVED_ATTRIBUTE_NAMES); receiveMessageRequestBuilder.messageAttributeNames(messageAttributeNames); receiveMessageRequest = receiveMessageRequestBuilder.build(); @@ -344,7 +343,7 @@ public ReceiveMessageResponse receiveMessage(ReceiveMessageRequest receiveMessag // remove the additional attribute before returning the message // to user. Map messageAttributes = new HashMap<>(message.messageAttributes()); - messageAttributes.keySet().removeAll(RESERVED_ATTRIBUTE_NAMES); + messageAttributes.keySet().removeAll(AmazonSQSExtendedClientUtil.RESERVED_ATTRIBUTE_NAMES); messageBuilder.messageAttributes(messageAttributes); // Embed s3 object pointer in the receipt handle. @@ -620,9 +619,10 @@ public SendMessageBatchResponse sendMessageBatch(SendMessageBatchRequest sendMes boolean hasS3Entries = false; for (SendMessageBatchRequestEntry entry : sendMessageBatchRequest.entries()) { //Check message attributes for ExtendedClient related constraints - checkMessageAttributes(entry.messageAttributes()); + checkMessageAttributes(clientConfiguration.getPayloadSizeThreshold(), entry.messageAttributes()); - if (clientConfiguration.isAlwaysThroughS3() || isLarge(entry)) { + if (clientConfiguration.isAlwaysThroughS3() + || isLarge(clientConfiguration.getPayloadSizeThreshold(), entry)) { entry = storeMessageInS3(entry); hasS3Entries = true; } @@ -837,121 +837,6 @@ public PurgeQueueResponse purgeQueue(PurgeQueueRequest purgeQueueRequest) return super.purgeQueue(purgeQueueRequestBuilder.build()); } - private void checkMessageAttributes(Map messageAttributes) { - int msgAttributesSize = getMsgAttributesSize(messageAttributes); - if (msgAttributesSize > clientConfiguration.getPayloadSizeThreshold()) { - String errorMessage = "Total size of Message attributes is " + msgAttributesSize - + " bytes which is larger than the threshold of " + clientConfiguration.getPayloadSizeThreshold() - + " Bytes. Consider including the payload in the message body instead of message attributes."; - LOG.error(errorMessage); - throw SdkClientException.create(errorMessage); - } - - int messageAttributesNum = messageAttributes.size(); - if (messageAttributesNum > SQSExtendedClientConstants.MAX_ALLOWED_ATTRIBUTES) { - String errorMessage = "Number of message attributes [" + messageAttributesNum - + "] exceeds the maximum allowed for large-payload messages [" - + SQSExtendedClientConstants.MAX_ALLOWED_ATTRIBUTES + "]."; - LOG.error(errorMessage); - throw SdkClientException.create(errorMessage); - } - Optional largePayloadAttributeName = getReservedAttributeNameIfPresent(messageAttributes); - - if (largePayloadAttributeName.isPresent()) { - String errorMessage = "Message attribute name " + largePayloadAttributeName.get() - + " is reserved for use by SQS extended client."; - LOG.error(errorMessage); - throw SdkClientException.create(errorMessage); - } - } - - /** - * TODO: Wrap the message pointer as-is to the receiptHandle so that it can be generic - * and does not use any LargeMessageStore implementation specific details. - */ - private String embedS3PointerInReceiptHandle(String receiptHandle, String pointer) { - PayloadS3Pointer s3Pointer = PayloadS3Pointer.fromJson(pointer); - String s3MsgBucketName = s3Pointer.getS3BucketName(); - String s3MsgKey = s3Pointer.getS3Key(); - - String modifiedReceiptHandle = SQSExtendedClientConstants.S3_BUCKET_NAME_MARKER + s3MsgBucketName - + SQSExtendedClientConstants.S3_BUCKET_NAME_MARKER + SQSExtendedClientConstants.S3_KEY_MARKER - + s3MsgKey + SQSExtendedClientConstants.S3_KEY_MARKER + receiptHandle; - return modifiedReceiptHandle; - } - - private String getOrigReceiptHandle(String receiptHandle) { - int secondOccurence = receiptHandle.indexOf(SQSExtendedClientConstants.S3_KEY_MARKER, - receiptHandle.indexOf(SQSExtendedClientConstants.S3_KEY_MARKER) + 1); - return receiptHandle.substring(secondOccurence + SQSExtendedClientConstants.S3_KEY_MARKER.length()); - } - - private String getFromReceiptHandleByMarker(String receiptHandle, String marker) { - int firstOccurence = receiptHandle.indexOf(marker); - int secondOccurence = receiptHandle.indexOf(marker, firstOccurence + 1); - return receiptHandle.substring(firstOccurence + marker.length(), secondOccurence); - } - - private boolean isS3ReceiptHandle(String receiptHandle) { - return receiptHandle.contains(SQSExtendedClientConstants.S3_BUCKET_NAME_MARKER) - && receiptHandle.contains(SQSExtendedClientConstants.S3_KEY_MARKER); - } - - private String getMessagePointerFromModifiedReceiptHandle(String receiptHandle) { - String s3MsgBucketName = getFromReceiptHandleByMarker(receiptHandle, SQSExtendedClientConstants.S3_BUCKET_NAME_MARKER); - String s3MsgKey = getFromReceiptHandleByMarker(receiptHandle, SQSExtendedClientConstants.S3_KEY_MARKER); - - PayloadS3Pointer payloadS3Pointer = new PayloadS3Pointer(s3MsgBucketName, s3MsgKey); - return payloadS3Pointer.toJson(); - } - - private boolean isLarge(SendMessageRequest sendMessageRequest) { - int msgAttributesSize = getMsgAttributesSize(sendMessageRequest.messageAttributes()); - long msgBodySize = Util.getStringSizeInBytes(sendMessageRequest.messageBody()); - long totalMsgSize = msgAttributesSize + msgBodySize; - return (totalMsgSize > clientConfiguration.getPayloadSizeThreshold()); - } - - private boolean isLarge(SendMessageBatchRequestEntry batchEntry) { - int msgAttributesSize = getMsgAttributesSize(batchEntry.messageAttributes()); - long msgBodySize = Util.getStringSizeInBytes(batchEntry.messageBody()); - long totalMsgSize = msgAttributesSize + msgBodySize; - return (totalMsgSize > clientConfiguration.getPayloadSizeThreshold()); - } - - private Optional getReservedAttributeNameIfPresent(Map msgAttributes) { - String reservedAttributeName = null; - if (msgAttributes.containsKey(SQSExtendedClientConstants.RESERVED_ATTRIBUTE_NAME)) { - reservedAttributeName = SQSExtendedClientConstants.RESERVED_ATTRIBUTE_NAME; - } else if (msgAttributes.containsKey(LEGACY_RESERVED_ATTRIBUTE_NAME)) { - reservedAttributeName = LEGACY_RESERVED_ATTRIBUTE_NAME; - } - return Optional.ofNullable(reservedAttributeName); - } - - private int getMsgAttributesSize(Map msgAttributes) { - int totalMsgAttributesSize = 0; - for (Map.Entry entry : msgAttributes.entrySet()) { - totalMsgAttributesSize += Util.getStringSizeInBytes(entry.getKey()); - - MessageAttributeValue entryVal = entry.getValue(); - if (entryVal.dataType() != null) { - totalMsgAttributesSize += Util.getStringSizeInBytes(entryVal.dataType()); - } - - String stringVal = entryVal.stringValue(); - if (stringVal != null) { - totalMsgAttributesSize += Util.getStringSizeInBytes(entryVal.stringValue()); - } - - SdkBytes binaryVal = entryVal.binaryValue(); - if (binaryVal != null) { - totalMsgAttributesSize += binaryVal.asByteArray().length; - } - } - return totalMsgAttributesSize; - } - private SendMessageBatchRequestEntry storeMessageInS3(SendMessageBatchRequestEntry batchEntry) { // Read the content of the message from message body @@ -962,7 +847,8 @@ private SendMessageBatchRequestEntry storeMessageInS3(SendMessageBatchRequestEnt SendMessageBatchRequestEntry.Builder batchEntryBuilder = batchEntry.toBuilder(); batchEntryBuilder.messageAttributes( - updateMessageAttributePayloadSize(batchEntry.messageAttributes(), messageContentSize)); + updateMessageAttributePayloadSize(batchEntry.messageAttributes(), messageContentSize, + clientConfiguration.usesLegacyReservedAttributeName())); // Store the message content in S3. String largeMessagePointer = payloadStore.storeOriginalPayload(messageContentStr); @@ -981,7 +867,8 @@ private SendMessageRequest storeMessageInS3(SendMessageRequest sendMessageReques SendMessageRequest.Builder sendMessageRequestBuilder = sendMessageRequest.toBuilder(); sendMessageRequestBuilder.messageAttributes( - updateMessageAttributePayloadSize(sendMessageRequest.messageAttributes(), messageContentSize)); + updateMessageAttributePayloadSize(sendMessageRequest.messageAttributes(), messageContentSize, + clientConfiguration.usesLegacyReservedAttributeName())); // Store the message content in S3. String largeMessagePointer = payloadStore.storeOriginalPayload(messageContentStr); @@ -990,31 +877,7 @@ private SendMessageRequest storeMessageInS3(SendMessageRequest sendMessageReques return sendMessageRequestBuilder.build(); } - private Map updateMessageAttributePayloadSize( - Map messageAttributes, Long messageContentSize) { - Map updatedMessageAttributes = new HashMap<>(messageAttributes); - - // Add a new message attribute as a flag - MessageAttributeValue.Builder messageAttributeValueBuilder = MessageAttributeValue.builder(); - messageAttributeValueBuilder.dataType("Number"); - messageAttributeValueBuilder.stringValue(messageContentSize.toString()); - MessageAttributeValue messageAttributeValue = messageAttributeValueBuilder.build(); - - if (!clientConfiguration.usesLegacyReservedAttributeName()) { - updatedMessageAttributes.put(SQSExtendedClientConstants.RESERVED_ATTRIBUTE_NAME, messageAttributeValue); - } else { - updatedMessageAttributes.put(LEGACY_RESERVED_ATTRIBUTE_NAME, messageAttributeValue); - } - return updatedMessageAttributes; - } - - @SuppressWarnings("unchecked") private static T appendUserAgent(final T builder) { - return (T) builder - .overrideConfiguration( - AwsRequestOverrideConfiguration.builder() - .addApiName(ApiName.builder().name(USER_AGENT_NAME) - .version(USER_AGENT_VERSION).build()) - .build()); + return AmazonSQSExtendedClientUtil.appendUserAgent(builder, USER_AGENT_NAME, USER_AGENT_VERSION); } } diff --git a/src/main/java/com/amazon/sqs/javamessaging/AmazonSQSExtendedClientUtil.java b/src/main/java/com/amazon/sqs/javamessaging/AmazonSQSExtendedClientUtil.java new file mode 100644 index 0000000..5dfc282 --- /dev/null +++ b/src/main/java/com/amazon/sqs/javamessaging/AmazonSQSExtendedClientUtil.java @@ -0,0 +1,168 @@ +package com.amazon.sqs.javamessaging; + +import java.util.Arrays; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import org.apache.commons.logging.Log; +import org.apache.commons.logging.LogFactory; +import software.amazon.awssdk.awscore.AwsRequest; +import software.amazon.awssdk.awscore.AwsRequestOverrideConfiguration; +import software.amazon.awssdk.core.ApiName; +import software.amazon.awssdk.core.SdkBytes; +import software.amazon.awssdk.core.exception.SdkClientException; +import software.amazon.awssdk.services.sqs.model.MessageAttributeValue; +import software.amazon.awssdk.services.sqs.model.SendMessageBatchRequestEntry; +import software.amazon.awssdk.services.sqs.model.SendMessageRequest; +import software.amazon.payloadoffloading.PayloadS3Pointer; +import software.amazon.payloadoffloading.Util; + +public class AmazonSQSExtendedClientUtil { + private static final Log LOG = LogFactory.getLog(AmazonSQSExtendedClientUtil.class); + + public static final String LEGACY_RESERVED_ATTRIBUTE_NAME = "SQSLargePayloadSize"; + public static final List RESERVED_ATTRIBUTE_NAMES = Arrays.asList(LEGACY_RESERVED_ATTRIBUTE_NAME, + SQSExtendedClientConstants.RESERVED_ATTRIBUTE_NAME); + + public static void checkMessageAttributes(int payloadSizeThreshold, Map messageAttributes) { + int msgAttributesSize = getMsgAttributesSize(messageAttributes); + if (msgAttributesSize > payloadSizeThreshold) { + String errorMessage = "Total size of Message attributes is " + msgAttributesSize + + " bytes which is larger than the threshold of " + payloadSizeThreshold + + " Bytes. Consider including the payload in the message body instead of message attributes."; + LOG.error(errorMessage); + throw SdkClientException.create(errorMessage); + } + + int messageAttributesNum = messageAttributes.size(); + if (messageAttributesNum > SQSExtendedClientConstants.MAX_ALLOWED_ATTRIBUTES) { + String errorMessage = "Number of message attributes [" + messageAttributesNum + + "] exceeds the maximum allowed for large-payload messages [" + + SQSExtendedClientConstants.MAX_ALLOWED_ATTRIBUTES + "]."; + LOG.error(errorMessage); + throw SdkClientException.create(errorMessage); + } + Optional largePayloadAttributeName = getReservedAttributeNameIfPresent(messageAttributes); + + if (largePayloadAttributeName.isPresent()) { + String errorMessage = "Message attribute name " + largePayloadAttributeName.get() + + " is reserved for use by SQS extended client."; + LOG.error(errorMessage); + throw SdkClientException.create(errorMessage); + } + } + + public static Optional getReservedAttributeNameIfPresent(Map msgAttributes) { + String reservedAttributeName = null; + if (msgAttributes.containsKey(SQSExtendedClientConstants.RESERVED_ATTRIBUTE_NAME)) { + reservedAttributeName = SQSExtendedClientConstants.RESERVED_ATTRIBUTE_NAME; + } else if (msgAttributes.containsKey(LEGACY_RESERVED_ATTRIBUTE_NAME)) { + reservedAttributeName = LEGACY_RESERVED_ATTRIBUTE_NAME; + } + return Optional.ofNullable(reservedAttributeName); + } + + public static String embedS3PointerInReceiptHandle(String receiptHandle, String pointer) { + PayloadS3Pointer s3Pointer = PayloadS3Pointer.fromJson(pointer); + String s3MsgBucketName = s3Pointer.getS3BucketName(); + String s3MsgKey = s3Pointer.getS3Key(); + + return SQSExtendedClientConstants.S3_BUCKET_NAME_MARKER + s3MsgBucketName + + SQSExtendedClientConstants.S3_BUCKET_NAME_MARKER + SQSExtendedClientConstants.S3_KEY_MARKER + + s3MsgKey + SQSExtendedClientConstants.S3_KEY_MARKER + receiptHandle; + } + + public static String getOrigReceiptHandle(String receiptHandle) { + int secondOccurence = receiptHandle.indexOf(SQSExtendedClientConstants.S3_KEY_MARKER, + receiptHandle.indexOf(SQSExtendedClientConstants.S3_KEY_MARKER) + 1); + return receiptHandle.substring(secondOccurence + SQSExtendedClientConstants.S3_KEY_MARKER.length()); + } + + public static boolean isS3ReceiptHandle(String receiptHandle) { + return receiptHandle.contains(SQSExtendedClientConstants.S3_BUCKET_NAME_MARKER) + && receiptHandle.contains(SQSExtendedClientConstants.S3_KEY_MARKER); + } + + public static String getMessagePointerFromModifiedReceiptHandle(String receiptHandle) { + String s3MsgBucketName = getFromReceiptHandleByMarker( + receiptHandle, SQSExtendedClientConstants.S3_BUCKET_NAME_MARKER); + String s3MsgKey = getFromReceiptHandleByMarker(receiptHandle, SQSExtendedClientConstants.S3_KEY_MARKER); + + PayloadS3Pointer payloadS3Pointer = new PayloadS3Pointer(s3MsgBucketName, s3MsgKey); + return payloadS3Pointer.toJson(); + } + + public static boolean isLarge(int payloadSizeThreshold, SendMessageRequest sendMessageRequest) { + int msgAttributesSize = getMsgAttributesSize(sendMessageRequest.messageAttributes()); + long msgBodySize = Util.getStringSizeInBytes(sendMessageRequest.messageBody()); + long totalMsgSize = msgAttributesSize + msgBodySize; + return (totalMsgSize > payloadSizeThreshold); + } + + public static boolean isLarge(int payloadSizeThreshold, SendMessageBatchRequestEntry batchEntry) { + int msgAttributesSize = getMsgAttributesSize(batchEntry.messageAttributes()); + long msgBodySize = Util.getStringSizeInBytes(batchEntry.messageBody()); + long totalMsgSize = msgAttributesSize + msgBodySize; + return (totalMsgSize > payloadSizeThreshold); + } + + public static Map updateMessageAttributePayloadSize( + Map messageAttributes, Long messageContentSize, + boolean usesLegacyReservedAttributeName) { + Map updatedMessageAttributes = new HashMap<>(messageAttributes); + + // Add a new message attribute as a flag + MessageAttributeValue.Builder messageAttributeValueBuilder = MessageAttributeValue.builder(); + messageAttributeValueBuilder.dataType("Number"); + messageAttributeValueBuilder.stringValue(messageContentSize.toString()); + MessageAttributeValue messageAttributeValue = messageAttributeValueBuilder.build(); + + if (!usesLegacyReservedAttributeName) { + updatedMessageAttributes.put(SQSExtendedClientConstants.RESERVED_ATTRIBUTE_NAME, messageAttributeValue); + } else { + updatedMessageAttributes.put(LEGACY_RESERVED_ATTRIBUTE_NAME, messageAttributeValue); + } + return updatedMessageAttributes; + } + + @SuppressWarnings("unchecked") + public static T appendUserAgent( + final T builder, String userAgentName, String userAgentVersion) { + return (T) builder + .overrideConfiguration( + AwsRequestOverrideConfiguration.builder() + .addApiName(ApiName.builder().name(userAgentName) + .version(userAgentVersion).build()) + .build()); + } + + private static String getFromReceiptHandleByMarker(String receiptHandle, String marker) { + int firstOccurence = receiptHandle.indexOf(marker); + int secondOccurence = receiptHandle.indexOf(marker, firstOccurence + 1); + return receiptHandle.substring(firstOccurence + marker.length(), secondOccurence); + } + + private static int getMsgAttributesSize(Map msgAttributes) { + int totalMsgAttributesSize = 0; + for (Map.Entry entry : msgAttributes.entrySet()) { + totalMsgAttributesSize += Util.getStringSizeInBytes(entry.getKey()); + + MessageAttributeValue entryVal = entry.getValue(); + if (entryVal.dataType() != null) { + totalMsgAttributesSize += Util.getStringSizeInBytes(entryVal.dataType()); + } + + String stringVal = entryVal.stringValue(); + if (stringVal != null) { + totalMsgAttributesSize += Util.getStringSizeInBytes(entryVal.stringValue()); + } + + SdkBytes binaryVal = entryVal.binaryValue(); + if (binaryVal != null) { + totalMsgAttributesSize += binaryVal.asByteArray().length; + } + } + return totalMsgAttributesSize; + } +} diff --git a/src/main/java/com/amazon/sqs/javamessaging/ExtendedAsyncClientConfiguration.java b/src/main/java/com/amazon/sqs/javamessaging/ExtendedAsyncClientConfiguration.java new file mode 100644 index 0000000..ab2d322 --- /dev/null +++ b/src/main/java/com/amazon/sqs/javamessaging/ExtendedAsyncClientConfiguration.java @@ -0,0 +1,182 @@ +package com.amazon.sqs.javamessaging; + +import software.amazon.awssdk.annotations.NotThreadSafe; +import software.amazon.awssdk.services.s3.S3AsyncClient; +import software.amazon.awssdk.services.s3.model.ObjectCannedACL; +import software.amazon.payloadoffloading.PayloadStorageAsyncConfiguration; +import software.amazon.payloadoffloading.ServerSideEncryptionStrategy; + +/** + * Amazon SQS extended client configuration options such as async Amazon S3 client, + * bucket name, and message size threshold for large-payload messages. + */ +@NotThreadSafe +public class ExtendedAsyncClientConfiguration extends PayloadStorageAsyncConfiguration { + + private boolean cleanupS3Payload = true; + private boolean useLegacyReservedAttributeName = true; + private boolean ignorePayloadNotFound = false; + + public ExtendedAsyncClientConfiguration() { + this.setPayloadSizeThreshold(SQSExtendedClientConstants.DEFAULT_MESSAGE_SIZE_THRESHOLD); + } + + public ExtendedAsyncClientConfiguration(ExtendedAsyncClientConfiguration other) { + super(other); + this.cleanupS3Payload = other.doesCleanupS3Payload(); + this.useLegacyReservedAttributeName = other.usesLegacyReservedAttributeName(); + this.ignorePayloadNotFound = other.ignoresPayloadNotFound(); + } + + /** + * Enables asynchronous support for payload messages. + * @param s3Async + * Amazon S3 client which is going to be used for storing + * payload messages. + * @param s3BucketName + * Name of the bucket which is going to be used for storing + * payload messages. The bucket must be already created and + * configured in s3. + * @param cleanupS3Payload + * If set to true, would handle deleting the S3 object as part + * of deleting the message from SQS queue. Otherwise, would not + * attempt to delete the object from S3. If opted to not delete S3 + * objects its the responsibility to the message producer to handle + * the clean up appropriately. + */ + public void setPayloadSupportEnabled(S3AsyncClient s3Async, String s3BucketName, boolean cleanupS3Payload) { + setPayloadSupportEnabled(s3Async, s3BucketName); + this.cleanupS3Payload = cleanupS3Payload; + } + + /** + * Enables asynchronous support for payload messages. + * @param s3Async + * Amazon S3 client which is going to be used for storing + * payload messages. + * @param s3BucketName + * Name of the bucket which is going to be used for storing + * payload messages. The bucket must be already created and + * configured in s3. + * @param cleanupS3Payload + * If set to true, would handle deleting the S3 object as part + * of deleting the message from SQS queue. Otherwise, would not + * attempt to delete the object from S3. If opted to not delete S3 + * objects its the responsibility to the message producer to handle + * the clean up appropriately. + */ + public ExtendedAsyncClientConfiguration withPayloadSupportEnabled( + S3AsyncClient s3Async, String s3BucketName, boolean cleanupS3Payload) { + setPayloadSupportEnabled(s3Async, s3BucketName, cleanupS3Payload); + return this; + } + + @Override + public ExtendedAsyncClientConfiguration withPayloadSupportEnabled(S3AsyncClient s3Async, String s3BucketName) { + this.setPayloadSupportEnabled(s3Async, s3BucketName); + return this; + } + + /** + * Disables the utilization legacy payload attribute name when sending messages. + */ + public void setLegacyReservedAttributeNameDisabled() { + this.useLegacyReservedAttributeName = false; + } + + /** + * Disables the utilization legacy payload attribute name when sending messages. + */ + public ExtendedAsyncClientConfiguration withLegacyReservedAttributeNameDisabled() { + setLegacyReservedAttributeNameDisabled(); + return this; + } + + /** + * Sets whether or not messages should be removed from Amazon SQS + * when payloads are not found in Amazon S3. + * + * @param ignorePayloadNotFound + * Whether or not messages should be removed from Amazon SQS + * when payloads are not found in Amazon S3. Default: false + */ + public void setIgnorePayloadNotFound(boolean ignorePayloadNotFound) { + this.ignorePayloadNotFound = ignorePayloadNotFound; + } + + /** + * Sets whether or not messages should be removed from Amazon SQS + * when payloads are not found in Amazon S3. + * + * @param ignorePayloadNotFound + * Whether or not messages should be removed from Amazon SQS + * when payloads are not found in Amazon S3. Default: false + * @return the updated ExtendedAsyncClientConfiguration object. + */ + public ExtendedAsyncClientConfiguration withIgnorePayloadNotFound(boolean ignorePayloadNotFound) { + setIgnorePayloadNotFound(ignorePayloadNotFound); + return this; + } + + /** + * Checks whether or not clean up large objects in S3 is enabled. + * + * @return True if clean up is enabled when deleting the concerning SQS message. + * Default: true + */ + public boolean doesCleanupS3Payload() { + return cleanupS3Payload; + } + + /** + * Checks whether or not the configuration uses the legacy reserved attribute name. + * + * @return True if legacy reserved attribute name is used. + * Default: true + */ + + public boolean usesLegacyReservedAttributeName() { + return useLegacyReservedAttributeName; + } + + /** + * Checks whether or not messages should be removed from Amazon SQS + * when payloads are not found in Amazon S3. + * + * @return True if messages should be removed from Amazon SQS + * when payloads are not found in Amazon S3. Default: false + */ + public boolean ignoresPayloadNotFound() { + return ignorePayloadNotFound; + } + + @Override + public ExtendedAsyncClientConfiguration withAlwaysThroughS3(boolean alwaysThroughS3) { + setAlwaysThroughS3(alwaysThroughS3); + return this; + } + + @Override + public ExtendedAsyncClientConfiguration withObjectCannedACL(ObjectCannedACL objectCannedACL) { + this.setObjectCannedACL(objectCannedACL); + return this; + } + + @Override + public ExtendedAsyncClientConfiguration withPayloadSizeThreshold(int payloadSizeThreshold) { + this.setPayloadSizeThreshold(payloadSizeThreshold); + return this; + } + + @Override + public ExtendedAsyncClientConfiguration withPayloadSupportDisabled() { + this.setPayloadSupportDisabled(); + return this; + } + + @Override + public ExtendedAsyncClientConfiguration withServerSideEncryption(ServerSideEncryptionStrategy serverSideEncryption) { + this.setServerSideEncryptionStrategy(serverSideEncryption); + return this; + } +} \ No newline at end of file diff --git a/src/test/java/com/amazon/sqs/javamessaging/AmazonSQSExtendedClientTest.java b/src/test/java/com/amazon/sqs/javamessaging/AmazonSQSExtendedClientTest.java index d2e7798..49dcf80 100644 --- a/src/test/java/com/amazon/sqs/javamessaging/AmazonSQSExtendedClientTest.java +++ b/src/test/java/com/amazon/sqs/javamessaging/AmazonSQSExtendedClientTest.java @@ -15,6 +15,21 @@ package com.amazon.sqs.javamessaging; +import static com.amazon.sqs.javamessaging.AmazonSQSExtendedClient.USER_AGENT_NAME; +import static com.amazon.sqs.javamessaging.AmazonSQSExtendedClient.USER_AGENT_VERSION; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotEquals; +import static org.mockito.Matchers.any; +import static org.mockito.Matchers.eq; +import static org.mockito.Mockito.isA; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.never; +import static org.mockito.Mockito.spy; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.verifyZeroInteractions; +import static org.mockito.Mockito.when; + import java.util.ArrayList; import java.util.Arrays; import java.util.List; @@ -22,12 +37,10 @@ import java.util.UUID; import java.util.stream.Collectors; import java.util.stream.IntStream; - import org.junit.Assert; import org.junit.Before; import org.junit.Test; import org.mockito.ArgumentCaptor; - import software.amazon.awssdk.awscore.AwsRequestOverrideConfiguration; import software.amazon.awssdk.core.ApiName; import software.amazon.awssdk.core.ResponseInputStream; @@ -56,22 +69,6 @@ import software.amazon.payloadoffloading.ServerSideEncryptionFactory; import software.amazon.payloadoffloading.ServerSideEncryptionStrategy; -import static com.amazon.sqs.javamessaging.AmazonSQSExtendedClient.USER_AGENT_NAME; -import static com.amazon.sqs.javamessaging.AmazonSQSExtendedClient.USER_AGENT_VERSION; -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertNotEquals; -import static org.mockito.Matchers.any; -import static org.mockito.Matchers.eq; -import static org.mockito.Mockito.doThrow; -import static org.mockito.Mockito.mock; -import static org.mockito.Mockito.isA; -import static org.mockito.Mockito.when; -import static org.mockito.Mockito.spy; -import static org.mockito.Mockito.verify; -import static org.mockito.Mockito.never; -import static org.mockito.Mockito.verifyZeroInteractions; -import static org.mockito.Mockito.times; - /** * Tests the AmazonSQSExtendedClient class. */ @@ -266,7 +263,7 @@ public void testSendLargeMessageWithDefaultConfigThenLegacyReservedAttributeName verify(mockSqsBackend).sendMessage(sendMessageRequestCaptor.capture()); Map attributes = sendMessageRequestCaptor.getValue().messageAttributes(); - Assert.assertTrue(attributes.containsKey(AmazonSQSExtendedClient.LEGACY_RESERVED_ATTRIBUTE_NAME)); + Assert.assertTrue(attributes.containsKey(AmazonSQSExtendedClientUtil.LEGACY_RESERVED_ATTRIBUTE_NAME)); Assert.assertFalse(attributes.containsKey(SQSExtendedClientConstants.RESERVED_ATTRIBUTE_NAME)); } @@ -282,7 +279,7 @@ public void testSendLargeMessageWithGenericReservedAttributeNameConfigThenGeneri Map attributes = sendMessageRequestCaptor.getValue().messageAttributes(); Assert.assertTrue(attributes.containsKey(SQSExtendedClientConstants.RESERVED_ATTRIBUTE_NAME)); - Assert.assertFalse(attributes.containsKey(AmazonSQSExtendedClient.LEGACY_RESERVED_ATTRIBUTE_NAME)); + Assert.assertFalse(attributes.containsKey(AmazonSQSExtendedClientUtil.LEGACY_RESERVED_ATTRIBUTE_NAME)); } @Test @@ -366,7 +363,7 @@ public void testReceiveMessageMultipleTimesDoesNotAdditionallyAlterReceiveMessag @Test public void testReceiveMessage_when_MessageIsLarge_legacyReservedAttributeUsed() throws Exception { - testReceiveMessage_when_MessageIsLarge(AmazonSQSExtendedClient.LEGACY_RESERVED_ATTRIBUTE_NAME); + testReceiveMessage_when_MessageIsLarge(AmazonSQSExtendedClientUtil.LEGACY_RESERVED_ATTRIBUTE_NAME); } @Test @@ -390,7 +387,7 @@ public void testReceiveMessage_when_MessageIsSmall() throws Exception { assertEquals(expectedMessage, actualMessage.body()); Assert.assertTrue(actualMessage.messageAttributes().containsKey(expectedMessageAttributeName)); - Assert.assertFalse(actualMessage.messageAttributes().keySet().containsAll(AmazonSQSExtendedClient.RESERVED_ATTRIBUTE_NAMES)); + Assert.assertFalse(actualMessage.messageAttributes().keySet().containsAll(AmazonSQSExtendedClientUtil.RESERVED_ATTRIBUTE_NAMES)); verifyZeroInteractions(mockS3); } @@ -484,8 +481,8 @@ public void testWhenLargeMessageIsSentThenAttributeWithPayloadSizeIsAdded() { verify(mockSqsBackend).sendMessage(sendMessageRequestCaptor.capture()); Map attributes = sendMessageRequestCaptor.getValue().messageAttributes(); - assertEquals("Number", attributes.get(AmazonSQSExtendedClient.LEGACY_RESERVED_ATTRIBUTE_NAME).dataType()); - assertEquals(messageLength, (int) Integer.parseInt(attributes.get(AmazonSQSExtendedClient.LEGACY_RESERVED_ATTRIBUTE_NAME).stringValue())); + assertEquals("Number", attributes.get(AmazonSQSExtendedClientUtil.LEGACY_RESERVED_ATTRIBUTE_NAME).dataType()); + assertEquals(messageLength, (int) Integer.parseInt(attributes.get(AmazonSQSExtendedClientUtil.LEGACY_RESERVED_ATTRIBUTE_NAME).stringValue())); } @Test @@ -637,7 +634,7 @@ private void testReceiveMessage_when_MessageIsLarge(String reservedAttributeName Message actualMessage = actualReceiveMessageResponse.messages().get(0); assertEquals(expectedMessage, actualMessage.body()); - Assert.assertFalse(actualMessage.messageAttributes().keySet().containsAll(AmazonSQSExtendedClient.RESERVED_ATTRIBUTE_NAMES)); + Assert.assertFalse(actualMessage.messageAttributes().keySet().containsAll(AmazonSQSExtendedClientUtil.RESERVED_ATTRIBUTE_NAMES)); verify(mockS3, times(1)).getObject(isA(GetObjectRequest.class)); } From 747aee02632d21953579d7c35e27c834d11c2546 Mon Sep 17 00:00:00 2001 From: Richard Smith Date: Tue, 25 Oct 2022 17:00:20 -0700 Subject: [PATCH 2/4] Add UTs, and fix a couple bugs --- .../AmazonSQSExtendedAsyncClient.java | 3 +- .../AmazonSQSExtendedAsyncClientTest.java | 667 ++++++++++++++++++ .../ExtendedAsyncClientConfigurationTest.java | 83 +++ 3 files changed, 752 insertions(+), 1 deletion(-) create mode 100644 src/test/java/com/amazon/sqs/javamessaging/AmazonSQSExtendedAsyncClientTest.java create mode 100644 src/test/java/com/amazon/sqs/javamessaging/ExtendedAsyncClientConfigurationTest.java diff --git a/src/main/java/com/amazon/sqs/javamessaging/AmazonSQSExtendedAsyncClient.java b/src/main/java/com/amazon/sqs/javamessaging/AmazonSQSExtendedAsyncClient.java index d9a5992..5a09713 100644 --- a/src/main/java/com/amazon/sqs/javamessaging/AmazonSQSExtendedAsyncClient.java +++ b/src/main/java/com/amazon/sqs/javamessaging/AmazonSQSExtendedAsyncClient.java @@ -210,7 +210,7 @@ public CompletableFuture receiveMessage(ReceiveMessageRe // For each received message check if they are stored in S3. Optional largePayloadAttributeName = getReservedAttributeNameIfPresent( message.messageAttributes()); - if (largePayloadAttributeName.isPresent()) { + if (!largePayloadAttributeName.isPresent()) { // Not S3 modifiedMessageFutures.add(CompletableFuture.completedFuture(messageBuilder.build())); } else { @@ -361,6 +361,7 @@ public CompletableFuture sendMessageBatch( if (clientConfiguration.isAlwaysThroughS3() || isLarge(clientConfiguration.getPayloadSizeThreshold(), entry)) { batchEntryFutures.add(storeMessageInS3(entry)); + hasS3Entries = true; } else { batchEntryFutures.add(CompletableFuture.completedFuture(entry)); } diff --git a/src/test/java/com/amazon/sqs/javamessaging/AmazonSQSExtendedAsyncClientTest.java b/src/test/java/com/amazon/sqs/javamessaging/AmazonSQSExtendedAsyncClientTest.java new file mode 100644 index 0000000..798dfda --- /dev/null +++ b/src/test/java/com/amazon/sqs/javamessaging/AmazonSQSExtendedAsyncClientTest.java @@ -0,0 +1,667 @@ +package com.amazon.sqs.javamessaging; + +import static com.amazon.sqs.javamessaging.AmazonSQSExtendedAsyncClient.USER_AGENT_NAME; +import static com.amazon.sqs.javamessaging.AmazonSQSExtendedAsyncClient.USER_AGENT_VERSION; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotEquals; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.ArgumentMatchers.isA; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.never; +import static org.mockito.Mockito.spy; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.verifyNoInteractions; +import static org.mockito.Mockito.when; + +import java.nio.charset.StandardCharsets; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; +import java.util.Map; +import java.util.UUID; +import java.util.concurrent.CompletableFuture; +import java.util.stream.Collectors; +import java.util.stream.IntStream; +import org.junit.Assert; +import org.junit.Before; +import org.junit.Test; +import org.mockito.ArgumentCaptor; +import software.amazon.awssdk.awscore.AwsRequestOverrideConfiguration; +import software.amazon.awssdk.core.ApiName; +import software.amazon.awssdk.core.ResponseBytes; +import software.amazon.awssdk.core.async.AsyncRequestBody; +import software.amazon.awssdk.core.async.AsyncResponseTransformer; +import software.amazon.awssdk.services.s3.S3AsyncClient; +import software.amazon.awssdk.services.s3.model.DeleteObjectRequest; +import software.amazon.awssdk.services.s3.model.DeleteObjectResponse; +import software.amazon.awssdk.services.s3.model.GetObjectRequest; +import software.amazon.awssdk.services.s3.model.GetObjectResponse; +import software.amazon.awssdk.services.s3.model.ObjectCannedACL; +import software.amazon.awssdk.services.s3.model.PutObjectRequest; +import software.amazon.awssdk.services.sqs.SqsAsyncClient; +import software.amazon.awssdk.services.sqs.model.DeleteMessageBatchRequest; +import software.amazon.awssdk.services.sqs.model.DeleteMessageBatchRequestEntry; +import software.amazon.awssdk.services.sqs.model.DeleteMessageBatchResponse; +import software.amazon.awssdk.services.sqs.model.DeleteMessageRequest; +import software.amazon.awssdk.services.sqs.model.DeleteMessageResponse; +import software.amazon.awssdk.services.sqs.model.Message; +import software.amazon.awssdk.services.sqs.model.MessageAttributeValue; +import software.amazon.awssdk.services.sqs.model.ReceiveMessageRequest; +import software.amazon.awssdk.services.sqs.model.ReceiveMessageResponse; +import software.amazon.awssdk.services.sqs.model.SendMessageBatchRequest; +import software.amazon.awssdk.services.sqs.model.SendMessageBatchRequestEntry; +import software.amazon.awssdk.services.sqs.model.SendMessageBatchResponse; +import software.amazon.awssdk.services.sqs.model.SendMessageRequest; +import software.amazon.awssdk.services.sqs.model.SendMessageResponse; +import software.amazon.awssdk.utils.ImmutableMap; +import software.amazon.payloadoffloading.PayloadS3Pointer; +import software.amazon.payloadoffloading.ServerSideEncryptionFactory; +import software.amazon.payloadoffloading.ServerSideEncryptionStrategy; + +public class AmazonSQSExtendedAsyncClientTest { + + private SqsAsyncClient extendedSqsWithDefaultConfig; + private SqsAsyncClient extendedSqsWithCustomKMS; + private SqsAsyncClient extendedSqsWithDefaultKMS; + private SqsAsyncClient extendedSqsWithGenericReservedAttributeName; + private SqsAsyncClient mockSqsBackend; + private S3AsyncClient mockS3; + private static final String S3_BUCKET_NAME = "test-bucket-name"; + private static final String SQS_QUEUE_URL = "test-queue-url"; + private static final String S3_SERVER_SIDE_ENCRYPTION_KMS_KEY_ID = "test-customer-managed-kms-key-id"; + + private static final int LESS_THAN_SQS_SIZE_LIMIT = 3; + private static final int SQS_SIZE_LIMIT = 262144; + private static final int MORE_THAN_SQS_SIZE_LIMIT = SQS_SIZE_LIMIT + 1; + private static final ServerSideEncryptionStrategy SERVER_SIDE_ENCRYPTION_CUSTOM_STRATEGY = ServerSideEncryptionFactory.customerKey(S3_SERVER_SIDE_ENCRYPTION_KMS_KEY_ID); + private static final ServerSideEncryptionStrategy SERVER_SIDE_ENCRYPTION_DEFAULT_STRATEGY = ServerSideEncryptionFactory.awsManagedCmk(); + + // should be > 1 and << SQS_SIZE_LIMIT + private static final int ARBITRARY_SMALLER_THRESHOLD = 500; + + @Before + public void setupClients() { + mockS3 = mock(S3AsyncClient.class); + mockSqsBackend = mock(SqsAsyncClient.class); + when(mockS3.putObject(isA(PutObjectRequest.class), isA(AsyncRequestBody.class))).thenReturn( + CompletableFuture.completedFuture(null)); + when(mockS3.deleteObject(isA(DeleteObjectRequest.class))).thenReturn( + CompletableFuture.completedFuture(DeleteObjectResponse.builder().build())); + when(mockSqsBackend.sendMessage(isA(SendMessageRequest.class))).thenReturn( + CompletableFuture.completedFuture(SendMessageResponse.builder().build())); + when(mockSqsBackend.sendMessageBatch(isA(SendMessageBatchRequest.class))).thenReturn( + CompletableFuture.completedFuture(SendMessageBatchResponse.builder().build())); + when(mockSqsBackend.deleteMessage(isA(DeleteMessageRequest.class))).thenReturn( + CompletableFuture.completedFuture(DeleteMessageResponse.builder().build())); + when(mockSqsBackend.deleteMessageBatch(isA(DeleteMessageBatchRequest.class))).thenReturn( + CompletableFuture.completedFuture(DeleteMessageBatchResponse.builder().build())); + + ExtendedAsyncClientConfiguration extendedClientConfiguration = new ExtendedAsyncClientConfiguration() + .withPayloadSupportEnabled(mockS3, S3_BUCKET_NAME); + + ExtendedAsyncClientConfiguration extendedClientConfigurationWithCustomKMS = new ExtendedAsyncClientConfiguration() + .withPayloadSupportEnabled(mockS3, S3_BUCKET_NAME) + .withServerSideEncryption(SERVER_SIDE_ENCRYPTION_CUSTOM_STRATEGY); + + ExtendedAsyncClientConfiguration extendedClientConfigurationWithDefaultKMS = new ExtendedAsyncClientConfiguration() + .withPayloadSupportEnabled(mockS3, S3_BUCKET_NAME) + .withServerSideEncryption(SERVER_SIDE_ENCRYPTION_DEFAULT_STRATEGY); + + ExtendedAsyncClientConfiguration extendedClientConfigurationWithGenericReservedAttributeName = new ExtendedAsyncClientConfiguration() + .withPayloadSupportEnabled(mockS3, S3_BUCKET_NAME).withLegacyReservedAttributeNameDisabled(); + + extendedSqsWithDefaultConfig = spy(new AmazonSQSExtendedAsyncClient(mockSqsBackend, extendedClientConfiguration)); + extendedSqsWithCustomKMS = spy(new AmazonSQSExtendedAsyncClient(mockSqsBackend, extendedClientConfigurationWithCustomKMS)); + extendedSqsWithDefaultKMS = spy(new AmazonSQSExtendedAsyncClient(mockSqsBackend, extendedClientConfigurationWithDefaultKMS)); + extendedSqsWithGenericReservedAttributeName = spy(new AmazonSQSExtendedAsyncClient(mockSqsBackend, extendedClientConfigurationWithGenericReservedAttributeName)); + } + + @Test + public void testWhenSendMessageWithLargePayloadSupportDisabledThenS3IsNotUsedAndSqsBackendIsResponsibleToFailItWithDeprecatedMethod() { + int messageLength = MORE_THAN_SQS_SIZE_LIMIT; + String messageBody = generateStringWithLength(messageLength); + ExtendedAsyncClientConfiguration extendedClientConfiguration = new ExtendedAsyncClientConfiguration() + .withPayloadSupportDisabled(); + SqsAsyncClient sqsExtended = spy(new AmazonSQSExtendedAsyncClient(mockSqsBackend, extendedClientConfiguration)); + + SendMessageRequest messageRequest = SendMessageRequest.builder() + .queueUrl(SQS_QUEUE_URL) + .messageBody(messageBody) + .overrideConfiguration( + AwsRequestOverrideConfiguration.builder() + .addApiName(ApiName.builder().name(USER_AGENT_NAME).version(USER_AGENT_VERSION).build()) + .build()) + .build(); + sqsExtended.sendMessage(messageRequest); + + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(SendMessageRequest.class); + + verify(mockS3, never()).putObject(isA(PutObjectRequest.class), isA(AsyncRequestBody.class)); + verify(mockSqsBackend).sendMessage(argumentCaptor.capture()); + assertEquals(messageRequest.queueUrl(), argumentCaptor.getValue().queueUrl()); + assertEquals(messageRequest.messageBody(), argumentCaptor.getValue().messageBody()); + assertEquals(messageRequest.overrideConfiguration().get().apiNames().get(0).name(), argumentCaptor.getValue().overrideConfiguration().get().apiNames().get(0).name()); + assertEquals(messageRequest.overrideConfiguration().get().apiNames().get(0).version(), argumentCaptor.getValue().overrideConfiguration().get().apiNames().get(0).version()); + } + + @Test + public void testWhenSendMessageWithAlwaysThroughS3AndMessageIsSmallThenItIsStillStoredInS3WithDeprecatedMethod() { + int messageLength = LESS_THAN_SQS_SIZE_LIMIT; + String messageBody = generateStringWithLength(messageLength); + ExtendedAsyncClientConfiguration extendedClientConfiguration = new ExtendedAsyncClientConfiguration() + .withPayloadSupportEnabled(mockS3, S3_BUCKET_NAME).withAlwaysThroughS3(true); + SqsAsyncClient sqsExtended = spy(new AmazonSQSExtendedAsyncClient(mockSqsBackend, extendedClientConfiguration)); + + SendMessageRequest messageRequest = SendMessageRequest.builder().queueUrl(SQS_QUEUE_URL).messageBody(messageBody).build(); + sqsExtended.sendMessage(messageRequest).join(); + + verify(mockS3, times(1)).putObject(isA(PutObjectRequest.class), isA(AsyncRequestBody.class)); + } + + @Test + public void testWhenSendMessageWithSetMessageSizeThresholdThenThresholdIsHonoredWithDeprecatedMethod() { + int messageLength = ARBITRARY_SMALLER_THRESHOLD * 2; + String messageBody = generateStringWithLength(messageLength); + ExtendedAsyncClientConfiguration extendedClientConfiguration = new ExtendedAsyncClientConfiguration() + .withPayloadSupportEnabled(mockS3, S3_BUCKET_NAME).withPayloadSizeThreshold(ARBITRARY_SMALLER_THRESHOLD); + + SqsAsyncClient sqsExtended = spy(new AmazonSQSExtendedAsyncClient(mockSqsBackend, extendedClientConfiguration)); + + SendMessageRequest messageRequest = SendMessageRequest.builder().queueUrl(SQS_QUEUE_URL).messageBody(messageBody).build(); + sqsExtended.sendMessage(messageRequest).join(); + verify(mockS3, times(1)).putObject(isA(PutObjectRequest.class), isA(AsyncRequestBody.class)); + } + + @Test + public void testReceiveMessageMultipleTimesDoesNotAdditionallyAlterReceiveMessageRequestWithDeprecatedMethod() { + ExtendedAsyncClientConfiguration extendedClientConfiguration = new ExtendedAsyncClientConfiguration() + .withPayloadSupportEnabled(mockS3, S3_BUCKET_NAME); + SqsAsyncClient sqsExtended = spy(new AmazonSQSExtendedAsyncClient(mockSqsBackend, extendedClientConfiguration)); + when(mockSqsBackend.receiveMessage(isA(ReceiveMessageRequest.class))).thenReturn( + CompletableFuture.completedFuture(ReceiveMessageResponse.builder().build())); + + ReceiveMessageRequest messageRequest = ReceiveMessageRequest.builder().build(); + ReceiveMessageRequest expectedRequest = ReceiveMessageRequest.builder().build(); + + sqsExtended.receiveMessage(messageRequest).join(); + assertEquals(expectedRequest, messageRequest); + + sqsExtended.receiveMessage(messageRequest).join(); + assertEquals(expectedRequest, messageRequest); + } + + @Test + public void testWhenSendLargeMessageThenPayloadIsStoredInS3() { + String messageBody = generateStringWithLength(MORE_THAN_SQS_SIZE_LIMIT); + + SendMessageRequest messageRequest = SendMessageRequest.builder().queueUrl(SQS_QUEUE_URL).messageBody(messageBody).build(); + extendedSqsWithDefaultConfig.sendMessage(messageRequest).join(); + + verify(mockS3, times(1)).putObject(isA(PutObjectRequest.class), isA(AsyncRequestBody.class)); + } + + @Test + public void testWhenSendLargeMessage_WithoutKMS_ThenPayloadIsStoredInS3AndKMSKeyIdIsNotUsed() { + String messageBody = generateStringWithLength(MORE_THAN_SQS_SIZE_LIMIT); + + SendMessageRequest messageRequest = SendMessageRequest.builder().queueUrl(SQS_QUEUE_URL).messageBody(messageBody).build(); + extendedSqsWithDefaultConfig.sendMessage(messageRequest).join(); + + ArgumentCaptor putObjectRequestArgumentCaptor = ArgumentCaptor.forClass(PutObjectRequest.class); + ArgumentCaptor requestBodyArgumentCaptor = ArgumentCaptor.forClass(AsyncRequestBody.class); + verify(mockS3, times(1)).putObject(putObjectRequestArgumentCaptor.capture(), requestBodyArgumentCaptor.capture()); + + Assert.assertNull(putObjectRequestArgumentCaptor.getValue().serverSideEncryption()); + assertEquals(putObjectRequestArgumentCaptor.getValue().bucket(), S3_BUCKET_NAME); + } + + @Test + public void testWhenSendLargeMessage_WithCustomKMS_ThenPayloadIsStoredInS3AndCorrectKMSKeyIdIsNotUsed() { + String messageBody = generateStringWithLength(MORE_THAN_SQS_SIZE_LIMIT); + + SendMessageRequest messageRequest = SendMessageRequest.builder().queueUrl(SQS_QUEUE_URL).messageBody(messageBody).build(); + extendedSqsWithCustomKMS.sendMessage(messageRequest).join(); + + ArgumentCaptor putObjectRequestArgumentCaptor = ArgumentCaptor.forClass(PutObjectRequest.class); + ArgumentCaptor requestBodyArgumentCaptor = ArgumentCaptor.forClass(AsyncRequestBody.class); + verify(mockS3, times(1)).putObject(putObjectRequestArgumentCaptor.capture(), requestBodyArgumentCaptor.capture()); + + assertEquals(putObjectRequestArgumentCaptor.getValue().ssekmsKeyId(), S3_SERVER_SIDE_ENCRYPTION_KMS_KEY_ID); + assertEquals(putObjectRequestArgumentCaptor.getValue().bucket(), S3_BUCKET_NAME); + } + + @Test + public void testWhenSendLargeMessage_WithDefaultKMS_ThenPayloadIsStoredInS3AndCorrectKMSKeyIdIsNotUsed() { + String messageBody = generateStringWithLength(MORE_THAN_SQS_SIZE_LIMIT); + + SendMessageRequest messageRequest = SendMessageRequest.builder().queueUrl(SQS_QUEUE_URL).messageBody(messageBody).build(); + extendedSqsWithDefaultKMS.sendMessage(messageRequest).join(); + + ArgumentCaptor putObjectRequestArgumentCaptor = ArgumentCaptor.forClass(PutObjectRequest.class); + ArgumentCaptor requestBodyArgumentCaptor = ArgumentCaptor.forClass(AsyncRequestBody.class); + verify(mockS3, times(1)).putObject(putObjectRequestArgumentCaptor.capture(), requestBodyArgumentCaptor.capture()); + + Assert.assertTrue(putObjectRequestArgumentCaptor.getValue().serverSideEncryption() != null && + putObjectRequestArgumentCaptor.getValue().ssekmsKeyId() == null); + assertEquals(putObjectRequestArgumentCaptor.getValue().bucket(), S3_BUCKET_NAME); + } + + @Test + public void testSendLargeMessageWithDefaultConfigThenLegacyReservedAttributeNameIsUsed(){ + String messageBody = generateStringWithLength(MORE_THAN_SQS_SIZE_LIMIT); + SendMessageRequest messageRequest = SendMessageRequest.builder().queueUrl(SQS_QUEUE_URL).messageBody(messageBody).build(); + extendedSqsWithDefaultConfig.sendMessage(messageRequest).join(); + + ArgumentCaptor sendMessageRequestCaptor = ArgumentCaptor.forClass(SendMessageRequest.class); + verify(mockSqsBackend).sendMessage(sendMessageRequestCaptor.capture()); + + Map attributes = sendMessageRequestCaptor.getValue().messageAttributes(); + Assert.assertTrue(attributes.containsKey(AmazonSQSExtendedClientUtil.LEGACY_RESERVED_ATTRIBUTE_NAME)); + Assert.assertFalse(attributes.containsKey(SQSExtendedClientConstants.RESERVED_ATTRIBUTE_NAME)); + + } + + @Test + public void testSendLargeMessageWithGenericReservedAttributeNameConfigThenGenericReservedAttributeNameIsUsed(){ + String messageBody = generateStringWithLength(MORE_THAN_SQS_SIZE_LIMIT); + SendMessageRequest messageRequest = SendMessageRequest.builder().queueUrl(SQS_QUEUE_URL).messageBody(messageBody).build(); + extendedSqsWithGenericReservedAttributeName.sendMessage(messageRequest).join(); + + ArgumentCaptor sendMessageRequestCaptor = ArgumentCaptor.forClass(SendMessageRequest.class); + verify(mockSqsBackend).sendMessage(sendMessageRequestCaptor.capture()); + + Map attributes = sendMessageRequestCaptor.getValue().messageAttributes(); + Assert.assertTrue(attributes.containsKey(SQSExtendedClientConstants.RESERVED_ATTRIBUTE_NAME)); + Assert.assertFalse(attributes.containsKey(AmazonSQSExtendedClientUtil.LEGACY_RESERVED_ATTRIBUTE_NAME)); + } + + @Test + public void testWhenSendSmallMessageThenS3IsNotUsed() { + String messageBody = generateStringWithLength(SQS_SIZE_LIMIT); + + SendMessageRequest messageRequest = SendMessageRequest.builder().queueUrl(SQS_QUEUE_URL).messageBody(messageBody).build(); + extendedSqsWithDefaultConfig.sendMessage(messageRequest).join(); + + verify(mockS3, never()).putObject(isA(PutObjectRequest.class), isA(AsyncRequestBody.class)); + } + + @Test + public void testWhenSendMessageWithLargePayloadSupportDisabledThenS3IsNotUsedAndSqsBackendIsResponsibleToFailIt() { + String messageBody = generateStringWithLength(MORE_THAN_SQS_SIZE_LIMIT); + ExtendedAsyncClientConfiguration extendedClientConfiguration = new ExtendedAsyncClientConfiguration() + .withPayloadSupportDisabled(); + SqsAsyncClient sqsExtended = spy(new AmazonSQSExtendedAsyncClient(mockSqsBackend, extendedClientConfiguration)); + + SendMessageRequest messageRequest = SendMessageRequest.builder() + .queueUrl(SQS_QUEUE_URL) + .messageBody(messageBody) + .overrideConfiguration( + AwsRequestOverrideConfiguration.builder() + .addApiName(ApiName.builder().name(USER_AGENT_NAME).version(USER_AGENT_VERSION).build()) + .build()) + .build(); + sqsExtended.sendMessage(messageRequest).join(); + + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(SendMessageRequest.class); + + verify(mockS3, never()).putObject(isA(PutObjectRequest.class), isA(AsyncRequestBody.class)); + verify(mockSqsBackend).sendMessage(argumentCaptor.capture()); + assertEquals(messageRequest.queueUrl(), argumentCaptor.getValue().queueUrl()); + assertEquals(messageRequest.messageBody(), argumentCaptor.getValue().messageBody()); + assertEquals(messageRequest.overrideConfiguration().get().apiNames().get(0).name(), argumentCaptor.getValue().overrideConfiguration().get().apiNames().get(0).name()); + assertEquals(messageRequest.overrideConfiguration().get().apiNames().get(0).version(), argumentCaptor.getValue().overrideConfiguration().get().apiNames().get(0).version()); + } + + @Test + public void testWhenSendMessageWithAlwaysThroughS3AndMessageIsSmallThenItIsStillStoredInS3() { + String messageBody = generateStringWithLength(LESS_THAN_SQS_SIZE_LIMIT); + ExtendedAsyncClientConfiguration extendedClientConfiguration = new ExtendedAsyncClientConfiguration() + .withPayloadSupportEnabled(mockS3, S3_BUCKET_NAME).withAlwaysThroughS3(true); + SqsAsyncClient sqsExtended = spy(new AmazonSQSExtendedAsyncClient(mockSqsBackend, extendedClientConfiguration)); + + SendMessageRequest messageRequest = SendMessageRequest.builder().queueUrl(SQS_QUEUE_URL).messageBody(messageBody).build(); + sqsExtended.sendMessage(messageRequest).join(); + + verify(mockS3, times(1)).putObject(isA(PutObjectRequest.class), isA(AsyncRequestBody.class)); + } + + @Test + public void testWhenSendMessageWithSetMessageSizeThresholdThenThresholdIsHonored() { + int messageLength = ARBITRARY_SMALLER_THRESHOLD * 2; + String messageBody = generateStringWithLength(messageLength); + ExtendedAsyncClientConfiguration extendedClientConfiguration = new ExtendedAsyncClientConfiguration() + .withPayloadSupportEnabled(mockS3, S3_BUCKET_NAME).withPayloadSizeThreshold(ARBITRARY_SMALLER_THRESHOLD); + + SqsAsyncClient sqsExtended = spy(new AmazonSQSExtendedAsyncClient(mockSqsBackend, extendedClientConfiguration)); + + SendMessageRequest messageRequest = SendMessageRequest.builder().queueUrl(SQS_QUEUE_URL).messageBody(messageBody).build(); + sqsExtended.sendMessage(messageRequest).join(); + verify(mockS3, times(1)).putObject(isA(PutObjectRequest.class), isA(AsyncRequestBody.class)); + } + + @Test + public void testReceiveMessageMultipleTimesDoesNotAdditionallyAlterReceiveMessageRequest() { + when(mockSqsBackend.receiveMessage(isA(ReceiveMessageRequest.class))).thenReturn( + CompletableFuture.completedFuture(ReceiveMessageResponse.builder().build())); + + ReceiveMessageRequest messageRequest = ReceiveMessageRequest.builder().build(); + + ReceiveMessageRequest expectedRequest = ReceiveMessageRequest.builder().build(); + + extendedSqsWithDefaultConfig.receiveMessage(messageRequest).join(); + assertEquals(expectedRequest, messageRequest); + + extendedSqsWithDefaultConfig.receiveMessage(messageRequest).join(); + assertEquals(expectedRequest, messageRequest); + } + + @Test + public void testReceiveMessage_when_MessageIsLarge_legacyReservedAttributeUsed() throws Exception { + testReceiveMessage_when_MessageIsLarge(AmazonSQSExtendedClientUtil.LEGACY_RESERVED_ATTRIBUTE_NAME); + } + + @Test + public void testReceiveMessage_when_MessageIsLarge_ReservedAttributeUsed() throws Exception { + testReceiveMessage_when_MessageIsLarge(SQSExtendedClientConstants.RESERVED_ATTRIBUTE_NAME); + } + + @Test + public void testReceiveMessage_when_MessageIsSmall() throws Exception { + String expectedMessageAttributeName = "AnyMessageAttribute"; + String expectedMessage = "SmallMessage"; + Message message = Message.builder() + .messageAttributes(ImmutableMap.of(expectedMessageAttributeName, MessageAttributeValue.builder().build())) + .body(expectedMessage) + .build(); + when(mockSqsBackend.receiveMessage(isA(ReceiveMessageRequest.class))).thenReturn( + CompletableFuture.completedFuture(ReceiveMessageResponse.builder().messages(message).build())); + + ReceiveMessageRequest messageRequest = ReceiveMessageRequest.builder().build(); + ReceiveMessageResponse actualReceiveMessageResponse = extendedSqsWithDefaultConfig.receiveMessage(messageRequest).join(); + Message actualMessage = actualReceiveMessageResponse.messages().get(0); + + assertEquals(expectedMessage, actualMessage.body()); + Assert.assertTrue(actualMessage.messageAttributes().containsKey(expectedMessageAttributeName)); + Assert.assertFalse(actualMessage.messageAttributes().keySet().containsAll(AmazonSQSExtendedClientUtil.RESERVED_ATTRIBUTE_NAMES)); + verifyNoInteractions(mockS3); + } + + @Test + public void testWhenMessageBatchIsSentThenOnlyMessagesLargerThanThresholdAreStoredInS3() { + // This creates 10 messages, out of which only two are below the threshold (100K and 200K), + // and the other 8 are above the threshold + + int[] messageLengthForCounter = new int[] { + 100_000, + 300_000, + 400_000, + 500_000, + 600_000, + 700_000, + 800_000, + 900_000, + 200_000, + 1000_000 + }; + + List batchEntries = new ArrayList(); + for (int i = 0; i < 10; i++) { + int messageLength = messageLengthForCounter[i]; + String messageBody = generateStringWithLength(messageLength); + SendMessageBatchRequestEntry entry = SendMessageBatchRequestEntry.builder() + .id("entry_" + i) + .messageBody(messageBody) + .build(); + batchEntries.add(entry); + } + + SendMessageBatchRequest + batchRequest = SendMessageBatchRequest.builder().queueUrl(SQS_QUEUE_URL).entries(batchEntries).build(); + extendedSqsWithDefaultConfig.sendMessageBatch(batchRequest).join(); + + // There should be 8 puts for the 8 messages above the threshold + verify(mockS3, times(8)).putObject(isA(PutObjectRequest.class), isA(AsyncRequestBody.class)); + } + + @Test + public void testWhenMessageBatchIsLargeS3PointerIsCorrectlySentToSQSAndNotOriginalMessage() { + String messageBody = generateStringWithLength(LESS_THAN_SQS_SIZE_LIMIT); + ExtendedAsyncClientConfiguration extendedClientConfiguration = new ExtendedAsyncClientConfiguration() + .withPayloadSupportEnabled(mockS3, S3_BUCKET_NAME).withAlwaysThroughS3(true); + + SqsAsyncClient sqsExtended = spy(new AmazonSQSExtendedAsyncClient(mockSqsBackend, extendedClientConfiguration)); + + List batchEntries = new ArrayList(); + for (int i = 0; i < 10; i++) { + SendMessageBatchRequestEntry entry = SendMessageBatchRequestEntry.builder() + .id("entry_" + i) + .messageBody(messageBody) + .build(); + batchEntries.add(entry); + } + SendMessageBatchRequest batchRequest = SendMessageBatchRequest.builder().queueUrl(SQS_QUEUE_URL).entries(batchEntries).build(); + + sqsExtended.sendMessageBatch(batchRequest).join(); + + ArgumentCaptor sendMessageRequestCaptor = ArgumentCaptor.forClass(SendMessageBatchRequest.class); + verify(mockSqsBackend).sendMessageBatch(sendMessageRequestCaptor.capture()); + + for (SendMessageBatchRequestEntry entry : sendMessageRequestCaptor.getValue().entries()) { + assertNotEquals(messageBody, entry.messageBody()); + } + } + + @Test + public void testWhenSmallMessageIsSentThenNoAttributeIsAdded() { + String messageBody = generateStringWithLength(LESS_THAN_SQS_SIZE_LIMIT); + + SendMessageRequest messageRequest = SendMessageRequest.builder().queueUrl(SQS_QUEUE_URL).messageBody(messageBody).build(); + extendedSqsWithDefaultConfig.sendMessage(messageRequest).join(); + + ArgumentCaptor sendMessageRequestCaptor = ArgumentCaptor.forClass(SendMessageRequest.class); + verify(mockSqsBackend).sendMessage(sendMessageRequestCaptor.capture()); + + Map attributes = sendMessageRequestCaptor.getValue().messageAttributes(); + Assert.assertTrue(attributes.isEmpty()); + } + + @Test + public void testWhenLargeMessageIsSentThenAttributeWithPayloadSizeIsAdded() { + int messageLength = MORE_THAN_SQS_SIZE_LIMIT; + String messageBody = generateStringWithLength(messageLength); + + SendMessageRequest messageRequest = SendMessageRequest.builder().queueUrl(SQS_QUEUE_URL).messageBody(messageBody).build(); + extendedSqsWithDefaultConfig.sendMessage(messageRequest).join(); + + ArgumentCaptor sendMessageRequestCaptor = ArgumentCaptor.forClass(SendMessageRequest.class); + verify(mockSqsBackend).sendMessage(sendMessageRequestCaptor.capture()); + + Map attributes = sendMessageRequestCaptor.getValue().messageAttributes(); + assertEquals("Number", attributes.get(AmazonSQSExtendedClientUtil.LEGACY_RESERVED_ATTRIBUTE_NAME).dataType()); + assertEquals(messageLength, (int) Integer.parseInt(attributes.get(AmazonSQSExtendedClientUtil.LEGACY_RESERVED_ATTRIBUTE_NAME).stringValue())); + } + + @Test + public void testDefaultExtendedClientDeletesSmallMessage() { + // given + String receiptHandle = UUID.randomUUID().toString(); + DeleteMessageRequest + deleteRequest = DeleteMessageRequest.builder().queueUrl(SQS_QUEUE_URL).receiptHandle(receiptHandle).build(); + + // when + extendedSqsWithDefaultConfig.deleteMessage(deleteRequest).join(); + + // then + ArgumentCaptor deleteRequestCaptor = ArgumentCaptor.forClass(DeleteMessageRequest.class); + verify(mockSqsBackend).deleteMessage(deleteRequestCaptor.capture()); + assertEquals(receiptHandle, deleteRequestCaptor.getValue().receiptHandle()); + verifyNoInteractions(mockS3); + } + + @Test + public void testDefaultExtendedClientDeletesObjectS3UponMessageDelete() { + // given + String randomS3Key = UUID.randomUUID().toString(); + String originalReceiptHandle = UUID.randomUUID().toString(); + String largeMessageReceiptHandle = getLargeReceiptHandle(randomS3Key, originalReceiptHandle); + DeleteMessageRequest deleteRequest = DeleteMessageRequest.builder().queueUrl(SQS_QUEUE_URL).receiptHandle(largeMessageReceiptHandle).build(); + + // when + extendedSqsWithDefaultConfig.deleteMessage(deleteRequest).join(); + + // then + ArgumentCaptor deleteRequestCaptor = ArgumentCaptor.forClass(DeleteMessageRequest.class); + verify(mockSqsBackend).deleteMessage(deleteRequestCaptor.capture()); + assertEquals(originalReceiptHandle, deleteRequestCaptor.getValue().receiptHandle()); + DeleteObjectRequest deleteObjectRequest = DeleteObjectRequest.builder().bucket(S3_BUCKET_NAME).key(randomS3Key).build(); + verify(mockS3).deleteObject(eq(deleteObjectRequest)); + } + + @Test + public void testExtendedClientConfiguredDoesNotDeleteObjectFromS3UponDelete() { + // given + String randomS3Key = UUID.randomUUID().toString(); + String originalReceiptHandle = UUID.randomUUID().toString(); + String largeMessageReceiptHandle = getLargeReceiptHandle(randomS3Key, originalReceiptHandle); + DeleteMessageRequest deleteRequest = DeleteMessageRequest.builder().queueUrl(SQS_QUEUE_URL).receiptHandle(largeMessageReceiptHandle).build(); + + ExtendedAsyncClientConfiguration extendedClientConfiguration = new ExtendedAsyncClientConfiguration() + .withPayloadSupportEnabled(mockS3, S3_BUCKET_NAME, false); + + SqsAsyncClient extendedSqs = spy(new AmazonSQSExtendedAsyncClient(mockSqsBackend, extendedClientConfiguration)); + + // when + extendedSqs.deleteMessage(deleteRequest).join(); + + // then + ArgumentCaptor deleteRequestCaptor = ArgumentCaptor.forClass(DeleteMessageRequest.class); + verify(mockSqsBackend).deleteMessage(deleteRequestCaptor.capture()); + assertEquals(originalReceiptHandle, deleteRequestCaptor.getValue().receiptHandle()); + verifyNoInteractions(mockS3); + } + + @Test + public void testExtendedClientConfiguredDoesNotDeletesObjectsFromS3UponDeleteBatch() { + // given + int batchSize = 10; + List originalReceiptHandles = IntStream.range(0, batchSize) + .mapToObj(i -> UUID.randomUUID().toString()) + .collect(Collectors.toList()); + DeleteMessageBatchRequest deleteBatchRequest = generateLargeDeleteBatchRequest(originalReceiptHandles); + ExtendedAsyncClientConfiguration extendedClientConfiguration = new ExtendedAsyncClientConfiguration() + .withPayloadSupportEnabled(mockS3, S3_BUCKET_NAME, false); + SqsAsyncClient extendedSqs = spy(new AmazonSQSExtendedAsyncClient(mockSqsBackend, extendedClientConfiguration)); + + // when + extendedSqs.deleteMessageBatch(deleteBatchRequest).join(); + + // then + ArgumentCaptor deleteBatchRequestCaptor = ArgumentCaptor.forClass(DeleteMessageBatchRequest.class); + verify(mockSqsBackend, times(1)).deleteMessageBatch(deleteBatchRequestCaptor.capture()); + DeleteMessageBatchRequest request = deleteBatchRequestCaptor.getValue(); + assertEquals(originalReceiptHandles.size(), request.entries().size()); + IntStream.range(0, originalReceiptHandles.size()).forEach(i -> assertEquals( + originalReceiptHandles.get(i), + request.entries().get(i).receiptHandle())); + verifyNoInteractions(mockS3); + } + + @Test + public void testDefaultExtendedClientDeletesObjectsFromS3UponDeleteBatch() { + // given + int batchSize = 10; + List originalReceiptHandles = IntStream.range(0, batchSize) + .mapToObj(i -> UUID.randomUUID().toString()) + .collect(Collectors.toList()); + DeleteMessageBatchRequest deleteBatchRequest = generateLargeDeleteBatchRequest(originalReceiptHandles); + + // when + extendedSqsWithDefaultConfig.deleteMessageBatch(deleteBatchRequest).join(); + + // then + ArgumentCaptor deleteBatchRequestCaptor = ArgumentCaptor.forClass(DeleteMessageBatchRequest.class); + verify(mockSqsBackend, times(1)).deleteMessageBatch(deleteBatchRequestCaptor.capture()); + DeleteMessageBatchRequest request = deleteBatchRequestCaptor.getValue(); + assertEquals(originalReceiptHandles.size(), request.entries().size()); + IntStream.range(0, originalReceiptHandles.size()).forEach(i -> assertEquals( + originalReceiptHandles.get(i), + request.entries().get(i).receiptHandle())); + verify(mockS3, times(batchSize)).deleteObject(any(DeleteObjectRequest.class)); + } + + @Test + public void testWhenSendMessageWIthCannedAccessControlListDefined() { + ObjectCannedACL expected = ObjectCannedACL.BUCKET_OWNER_FULL_CONTROL; + String messageBody = generateStringWithLength(MORE_THAN_SQS_SIZE_LIMIT); + ExtendedAsyncClientConfiguration extendedClientConfiguration = new ExtendedAsyncClientConfiguration() + .withPayloadSupportEnabled(mockS3, S3_BUCKET_NAME).withObjectCannedACL(expected); + SqsAsyncClient sqsExtended = spy(new AmazonSQSExtendedAsyncClient(mockSqsBackend, extendedClientConfiguration)); + + SendMessageRequest messageRequest = SendMessageRequest.builder().queueUrl(SQS_QUEUE_URL).messageBody(messageBody).build(); + sqsExtended.sendMessage(messageRequest).join(); + + ArgumentCaptor captor = ArgumentCaptor.forClass(PutObjectRequest.class); + + verify(mockS3).putObject(captor.capture(), any(AsyncRequestBody.class)); + + assertEquals(expected, captor.getValue().acl()); + } + + private void testReceiveMessage_when_MessageIsLarge(String reservedAttributeName) throws Exception { + String pointer = new PayloadS3Pointer(S3_BUCKET_NAME, "S3Key").toJson(); + Message message = Message.builder() + .messageAttributes(ImmutableMap.of(reservedAttributeName, MessageAttributeValue.builder().build())) + .body(pointer) + .build(); + String expectedMessage = "LargeMessage"; + GetObjectRequest getObjectRequest = GetObjectRequest.builder() + .bucket(S3_BUCKET_NAME) + .key("S3Key") + .build(); + + ResponseBytes s3Object = ResponseBytes.fromByteArray( + GetObjectResponse.builder().build(), + expectedMessage.getBytes(StandardCharsets.UTF_8)); + when(mockSqsBackend.receiveMessage(isA(ReceiveMessageRequest.class))).thenReturn( + CompletableFuture.completedFuture(ReceiveMessageResponse.builder().messages(message).build())); + when(mockS3.getObject(isA(GetObjectRequest.class), isA(AsyncResponseTransformer.class))).thenReturn( + CompletableFuture.completedFuture(s3Object)); + + ReceiveMessageRequest messageRequest = ReceiveMessageRequest.builder().build(); + ReceiveMessageResponse actualReceiveMessageResponse = extendedSqsWithDefaultConfig.receiveMessage(messageRequest).join(); + Message actualMessage = actualReceiveMessageResponse.messages().get(0); + + assertEquals(expectedMessage, actualMessage.body()); + Assert.assertFalse(actualMessage.messageAttributes().keySet().containsAll(AmazonSQSExtendedClientUtil.RESERVED_ATTRIBUTE_NAMES)); + verify(mockS3, times(1)).getObject(isA(GetObjectRequest.class), isA(AsyncResponseTransformer.class)); + } + + private DeleteMessageBatchRequest generateLargeDeleteBatchRequest(List originalReceiptHandles) { + List deleteEntries = IntStream.range(0, originalReceiptHandles.size()) + .mapToObj(i -> DeleteMessageBatchRequestEntry.builder() + .id(Integer.toString(i)) + .receiptHandle(getSampleLargeReceiptHandle(originalReceiptHandles.get(i))) + .build()) + .collect(Collectors.toList()); + + return DeleteMessageBatchRequest.builder().queueUrl(SQS_QUEUE_URL).entries(deleteEntries).build(); + } + + private String getLargeReceiptHandle(String s3Key, String originalReceiptHandle) { + return SQSExtendedClientConstants.S3_BUCKET_NAME_MARKER + S3_BUCKET_NAME + + SQSExtendedClientConstants.S3_BUCKET_NAME_MARKER + SQSExtendedClientConstants.S3_KEY_MARKER + + s3Key + SQSExtendedClientConstants.S3_KEY_MARKER + originalReceiptHandle; + } + + private String getSampleLargeReceiptHandle(String originalReceiptHandle) { + return getLargeReceiptHandle(UUID.randomUUID().toString(), originalReceiptHandle); + } + + private String generateStringWithLength(int messageLength) { + char[] charArray = new char[messageLength]; + Arrays.fill(charArray, 'x'); + return new String(charArray); + } +} diff --git a/src/test/java/com/amazon/sqs/javamessaging/ExtendedAsyncClientConfigurationTest.java b/src/test/java/com/amazon/sqs/javamessaging/ExtendedAsyncClientConfigurationTest.java new file mode 100644 index 0000000..bd615a4 --- /dev/null +++ b/src/test/java/com/amazon/sqs/javamessaging/ExtendedAsyncClientConfigurationTest.java @@ -0,0 +1,83 @@ +package com.amazon.sqs.javamessaging; + +import static org.mockito.Mockito.mock; + +import org.junit.Assert; +import org.junit.Test; +import software.amazon.awssdk.services.s3.S3AsyncClient; +import software.amazon.payloadoffloading.ServerSideEncryptionFactory; +import software.amazon.payloadoffloading.ServerSideEncryptionStrategy; + +/** + * Tests the ExtendedAsyncClientConfiguration class. + */ +public class ExtendedAsyncClientConfigurationTest { + + private static String s3BucketName = "test-bucket-name"; + private static String s3ServerSideEncryptionKMSKeyId = "test-customer-managed-kms-key-id"; + private static ServerSideEncryptionStrategy serverSideEncryptionStrategy = ServerSideEncryptionFactory.customerKey(s3ServerSideEncryptionKMSKeyId); + + @Test + public void testCopyConstructor() { + S3AsyncClient s3 = mock(S3AsyncClient.class); + + boolean alwaysThroughS3 = true; + int messageSizeThreshold = 500; + boolean doesCleanupS3Payload = false; + + ExtendedAsyncClientConfiguration extendedClientConfig = new ExtendedAsyncClientConfiguration(); + + extendedClientConfig.withPayloadSupportEnabled(s3, s3BucketName, doesCleanupS3Payload) + .withAlwaysThroughS3(alwaysThroughS3).withPayloadSizeThreshold(messageSizeThreshold) + .withServerSideEncryption(serverSideEncryptionStrategy); + + ExtendedAsyncClientConfiguration newExtendedClientConfig = new ExtendedAsyncClientConfiguration(extendedClientConfig); + + Assert.assertEquals(s3, newExtendedClientConfig.getS3AsyncClient()); + Assert.assertEquals(s3BucketName, newExtendedClientConfig.getS3BucketName()); + Assert.assertEquals(serverSideEncryptionStrategy, newExtendedClientConfig.getServerSideEncryptionStrategy()); + Assert.assertTrue(newExtendedClientConfig.isPayloadSupportEnabled()); + Assert.assertEquals(doesCleanupS3Payload, newExtendedClientConfig.doesCleanupS3Payload()); + Assert.assertEquals(alwaysThroughS3, newExtendedClientConfig.isAlwaysThroughS3()); + Assert.assertEquals(messageSizeThreshold, newExtendedClientConfig.getPayloadSizeThreshold()); + + Assert.assertNotSame(newExtendedClientConfig, extendedClientConfig); + } + + @Test + public void testLargePayloadSupportEnabledWithDefaultDeleteFromS3Config() { + S3AsyncClient s3 = mock(S3AsyncClient.class); + ExtendedAsyncClientConfiguration extendedClientConfiguration = new ExtendedAsyncClientConfiguration(); + extendedClientConfiguration.setPayloadSupportEnabled(s3, s3BucketName); + + Assert.assertTrue(extendedClientConfiguration.isPayloadSupportEnabled()); + Assert.assertTrue(extendedClientConfiguration.doesCleanupS3Payload()); + Assert.assertNotNull(extendedClientConfiguration.getS3AsyncClient()); + Assert.assertEquals(s3BucketName, extendedClientConfiguration.getS3BucketName()); + + } + + @Test + public void testLargePayloadSupportEnabledWithDeleteFromS3Enabled() { + S3AsyncClient s3 = mock(S3AsyncClient.class); + ExtendedAsyncClientConfiguration extendedClientConfiguration = new ExtendedAsyncClientConfiguration(); + extendedClientConfiguration.setPayloadSupportEnabled(s3, s3BucketName, true); + + Assert.assertTrue(extendedClientConfiguration.isPayloadSupportEnabled()); + Assert.assertTrue(extendedClientConfiguration.doesCleanupS3Payload()); + Assert.assertNotNull(extendedClientConfiguration.getS3AsyncClient()); + Assert.assertEquals(s3BucketName, extendedClientConfiguration.getS3BucketName()); + } + + @Test + public void testLargePayloadSupportEnabledWithDeleteFromS3Disabled() { + S3AsyncClient s3 = mock(S3AsyncClient.class); + ExtendedAsyncClientConfiguration extendedClientConfiguration = new ExtendedAsyncClientConfiguration(); + extendedClientConfiguration.setPayloadSupportEnabled(s3, s3BucketName, false); + + Assert.assertTrue(extendedClientConfiguration.isPayloadSupportEnabled()); + Assert.assertFalse(extendedClientConfiguration.doesCleanupS3Payload()); + Assert.assertNotNull(extendedClientConfiguration.getS3AsyncClient()); + Assert.assertEquals(s3BucketName, extendedClientConfiguration.getS3BucketName()); + } +} From 9d5a2352d110b604f22486c41f4349c9a31665a8 Mon Sep 17 00:00:00 2001 From: Richard Smith Date: Tue, 25 Oct 2022 17:23:21 -0700 Subject: [PATCH 3/4] Throw argument validation exceptions immediately instead of deferring. --- .../AmazonSQSExtendedAsyncClient.java | 63 +++++++------------ 1 file changed, 22 insertions(+), 41 deletions(-) diff --git a/src/main/java/com/amazon/sqs/javamessaging/AmazonSQSExtendedAsyncClient.java b/src/main/java/com/amazon/sqs/javamessaging/AmazonSQSExtendedAsyncClient.java index 5a09713..9f80e87 100644 --- a/src/main/java/com/amazon/sqs/javamessaging/AmazonSQSExtendedAsyncClient.java +++ b/src/main/java/com/amazon/sqs/javamessaging/AmazonSQSExtendedAsyncClient.java @@ -51,9 +51,14 @@ /** * Amazon SQS Extended Async Client extends the functionality of Amazon Async SQS - * client. All service calls made using this client are asynchronous, and will return + * client. + * + *

+ * All service calls made using this client are asynchronous, and will return * immediately with a {@link CompletableFuture} that completes when the operation - * completes or when an exception is thrown. + * completes or when an exception is thrown. Argument validation exceptions are thrown + * immediately, and not through the future. + *

* *

* The Amazon SQS extended client enables sending and receiving large messages @@ -85,7 +90,9 @@ public class AmazonSQSExtendedAsyncClient extends AmazonSQSExtendedAsyncClientBa *

* All service calls made using this client are asynchronous, and will return * immediately with a {@link CompletableFuture} that completes when the operation - * completes or when an exception is thrown. + * completes or when an exception is thrown. Argument validation exceptions are thrown + * immediately, and not through the future. + *

* * @param sqsClient * The Amazon SQS async client to use to connect to Amazon SQS. @@ -102,7 +109,9 @@ public AmazonSQSExtendedAsyncClient(SqsAsyncClient sqsClient) { *

* All service calls made using this client are asynchronous, and will return * immediately with a {@link CompletableFuture} that completes when the operation - * completes or when an exception is thrown. + * completes or when an exception is thrown. Argument validation exceptions are thrown + * immediately, and not through the future. + *

* * @param sqsClient * The Amazon SQS async client to use to connect to Amazon SQS. @@ -130,9 +139,7 @@ public CompletableFuture sendMessage(SendMessageRequest sen if (sendMessageRequest == null) { String errorMessage = "sendMessageRequest cannot be null."; LOG.error(errorMessage); - CompletableFuture futureEx = new CompletableFuture<>(); - futureEx.completeExceptionally(SdkClientException.create(errorMessage)); - return futureEx; + throw SdkClientException.create(errorMessage); } SendMessageRequest.Builder sendMessageRequestBuilder = sendMessageRequest.toBuilder(); @@ -145,19 +152,11 @@ public CompletableFuture sendMessage(SendMessageRequest sen if (StringUtils.isEmpty(sendMessageRequest.messageBody())) { String errorMessage = "messageBody cannot be null or empty."; LOG.error(errorMessage); - CompletableFuture futureEx = new CompletableFuture<>(); - futureEx.completeExceptionally(SdkClientException.create(errorMessage)); - return futureEx; + throw SdkClientException.create(errorMessage); } //Check message attributes for ExtendedClient related constraints - try { - checkMessageAttributes(clientConfiguration.getPayloadSizeThreshold(), sendMessageRequest.messageAttributes()); - } catch (SdkClientException e) { - CompletableFuture futureEx = new CompletableFuture<>(); - futureEx.completeExceptionally(e); - return futureEx; - } + checkMessageAttributes(clientConfiguration.getPayloadSizeThreshold(), sendMessageRequest.messageAttributes()); if (clientConfiguration.isAlwaysThroughS3() || isLarge(clientConfiguration.getPayloadSizeThreshold(), sendMessageRequest)) { @@ -178,10 +177,7 @@ public CompletableFuture receiveMessage(ReceiveMessageRe if (receiveMessageRequest == null) { String errorMessage = "receiveMessageRequest cannot be null."; LOG.error(errorMessage); - - CompletableFuture future = new CompletableFuture<>(); - future.completeExceptionally(SdkClientException.create(errorMessage)); - return future; + throw SdkClientException.create(errorMessage); } ReceiveMessageRequest.Builder receiveMessageRequestBuilder = receiveMessageRequest.toBuilder(); @@ -266,10 +262,7 @@ public CompletableFuture deleteMessage(DeleteMessageReque if (deleteMessageRequest == null) { String errorMessage = "deleteMessageRequest cannot be null."; LOG.error(errorMessage); - - CompletableFuture future = new CompletableFuture<>(); - future.completeExceptionally(SdkClientException.create(errorMessage)); - return future; + throw SdkClientException.create(errorMessage); } DeleteMessageRequest.Builder deleteMessageRequestBuilder = deleteMessageRequest.toBuilder(); @@ -332,9 +325,7 @@ public CompletableFuture sendMessageBatch( if (sendMessageBatchRequestIn == null) { String errorMessage = "sendMessageBatchRequest cannot be null."; LOG.error(errorMessage); - CompletableFuture futureEx = new CompletableFuture<>(); - futureEx.completeExceptionally(SdkClientException.create(errorMessage)); - return futureEx; + throw SdkClientException.create(errorMessage); } SendMessageBatchRequest.Builder sendMessageBatchRequestBuilder = sendMessageBatchRequestIn.toBuilder(); @@ -350,13 +341,7 @@ public CompletableFuture sendMessageBatch( boolean hasS3Entries = false; for (SendMessageBatchRequestEntry entry : sendMessageBatchRequest.entries()) { //Check message attributes for ExtendedClient related constraints - try { - checkMessageAttributes(clientConfiguration.getPayloadSizeThreshold(), entry.messageAttributes()); - } catch (SdkClientException e) { - CompletableFuture futureEx = new CompletableFuture<>(); - futureEx.completeExceptionally(e); - return futureEx; - } + checkMessageAttributes(clientConfiguration.getPayloadSizeThreshold(), entry.messageAttributes()); if (clientConfiguration.isAlwaysThroughS3() || isLarge(clientConfiguration.getPayloadSizeThreshold(), entry)) { @@ -394,9 +379,7 @@ public CompletableFuture deleteMessageBatch( if (deleteMessageBatchRequest == null) { String errorMessage = "deleteMessageBatchRequest cannot be null."; LOG.error(errorMessage); - CompletableFuture futureEx = new CompletableFuture<>(); - futureEx.completeExceptionally(SdkClientException.create(errorMessage)); - return futureEx; + throw SdkClientException.create(errorMessage); } DeleteMessageBatchRequest.Builder deleteMessageBatchRequestBuilder = deleteMessageBatchRequest.toBuilder(); @@ -460,9 +443,7 @@ public CompletableFuture purgeQueue(PurgeQueueRequest purgeQ if (purgeQueueRequest == null) { String errorMessage = "purgeQueueRequest cannot be null."; LOG.error(errorMessage); - CompletableFuture futureEx = new CompletableFuture<>(); - futureEx.completeExceptionally(SdkClientException.create(errorMessage)); - return futureEx; + throw SdkClientException.create(errorMessage); } PurgeQueueRequest.Builder purgeQueueRequestBuilder = purgeQueueRequest.toBuilder(); From 5e3a5f7799491c0b9c722c09c99796258e8d5d21 Mon Sep 17 00:00:00 2001 From: Richard Smith Date: Thu, 27 Oct 2022 15:38:46 -0700 Subject: [PATCH 4/4] Remove unused local variable. Check for empty messages list as special case. --- .../sqs/javamessaging/AmazonSQSExtendedAsyncClient.java | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/src/main/java/com/amazon/sqs/javamessaging/AmazonSQSExtendedAsyncClient.java b/src/main/java/com/amazon/sqs/javamessaging/AmazonSQSExtendedAsyncClient.java index 9f80e87..e9faf1d 100644 --- a/src/main/java/com/amazon/sqs/javamessaging/AmazonSQSExtendedAsyncClient.java +++ b/src/main/java/com/amazon/sqs/javamessaging/AmazonSQSExtendedAsyncClient.java @@ -196,9 +196,13 @@ public CompletableFuture receiveMessage(ReceiveMessageRe return super.receiveMessage(receiveMessageRequest) .thenCompose(receiveMessageResponse -> { - ReceiveMessageResponse.Builder receiveMessageResponseBuilder = receiveMessageResponse.toBuilder(); - List messages = receiveMessageResponse.messages(); + + // Check for no messages. If so, no need to process further. + if (messages.isEmpty()) { + return CompletableFuture.completedFuture(messages); + } + List> modifiedMessageFutures = new ArrayList<>(messages.size()); for (Message message : messages) { Message.Builder messageBuilder = message.toBuilder();