An extension to the Amazon SQS client that enables sending and receiving messages up to 2GB via Amazon S3.
@@ -55,7 +55,7 @@
software.amazon.payloadoffloading
payloadoffloading-common
- 2.1.3
+ 2.2.0
diff --git a/src/main/java/com/amazon/sqs/javamessaging/AmazonSQSExtendedAsyncClient.java b/src/main/java/com/amazon/sqs/javamessaging/AmazonSQSExtendedAsyncClient.java
new file mode 100644
index 0000000..02404a2
--- /dev/null
+++ b/src/main/java/com/amazon/sqs/javamessaging/AmazonSQSExtendedAsyncClient.java
@@ -0,0 +1,511 @@
+package com.amazon.sqs.javamessaging;
+
+import static com.amazon.sqs.javamessaging.AmazonSQSExtendedClientUtil.checkMessageAttributes;
+import static com.amazon.sqs.javamessaging.AmazonSQSExtendedClientUtil.embedS3PointerInReceiptHandle;
+import static com.amazon.sqs.javamessaging.AmazonSQSExtendedClientUtil.getMessagePointerFromModifiedReceiptHandle;
+import static com.amazon.sqs.javamessaging.AmazonSQSExtendedClientUtil.getOrigReceiptHandle;
+import static com.amazon.sqs.javamessaging.AmazonSQSExtendedClientUtil.getReservedAttributeNameIfPresent;
+import static com.amazon.sqs.javamessaging.AmazonSQSExtendedClientUtil.isLarge;
+import static com.amazon.sqs.javamessaging.AmazonSQSExtendedClientUtil.isS3ReceiptHandle;
+import static com.amazon.sqs.javamessaging.AmazonSQSExtendedClientUtil.updateMessageAttributePayloadSize;
+
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.Optional;
+import java.util.UUID;
+import java.util.concurrent.CompletableFuture;
+import java.util.stream.Collectors;
+import org.apache.commons.logging.Log;
+import org.apache.commons.logging.LogFactory;
+import software.amazon.awssdk.awscore.AwsRequest;
+import software.amazon.awssdk.core.exception.SdkClientException;
+import software.amazon.awssdk.core.util.VersionInfo;
+import software.amazon.awssdk.services.sqs.SqsAsyncClient;
+import software.amazon.awssdk.services.sqs.model.ChangeMessageVisibilityBatchRequest;
+import software.amazon.awssdk.services.sqs.model.ChangeMessageVisibilityBatchRequestEntry;
+import software.amazon.awssdk.services.sqs.model.ChangeMessageVisibilityBatchResponse;
+import software.amazon.awssdk.services.sqs.model.ChangeMessageVisibilityRequest;
+import software.amazon.awssdk.services.sqs.model.ChangeMessageVisibilityResponse;
+import software.amazon.awssdk.services.sqs.model.DeleteMessageBatchRequest;
+import software.amazon.awssdk.services.sqs.model.DeleteMessageBatchRequestEntry;
+import software.amazon.awssdk.services.sqs.model.DeleteMessageBatchResponse;
+import software.amazon.awssdk.services.sqs.model.DeleteMessageRequest;
+import software.amazon.awssdk.services.sqs.model.DeleteMessageResponse;
+import software.amazon.awssdk.services.sqs.model.Message;
+import software.amazon.awssdk.services.sqs.model.MessageAttributeValue;
+import software.amazon.awssdk.services.sqs.model.PurgeQueueRequest;
+import software.amazon.awssdk.services.sqs.model.PurgeQueueResponse;
+import software.amazon.awssdk.services.sqs.model.ReceiveMessageRequest;
+import software.amazon.awssdk.services.sqs.model.ReceiveMessageResponse;
+import software.amazon.awssdk.services.sqs.model.SendMessageBatchRequest;
+import software.amazon.awssdk.services.sqs.model.SendMessageBatchRequestEntry;
+import software.amazon.awssdk.services.sqs.model.SendMessageBatchResponse;
+import software.amazon.awssdk.services.sqs.model.SendMessageRequest;
+import software.amazon.awssdk.services.sqs.model.SendMessageResponse;
+import software.amazon.awssdk.utils.StringUtils;
+import software.amazon.payloadoffloading.PayloadStoreAsync;
+import software.amazon.payloadoffloading.S3AsyncDao;
+import software.amazon.payloadoffloading.S3BackedPayloadStoreAsync;
+import software.amazon.payloadoffloading.Util;
+
+/**
+ * Amazon SQS Extended Async Client extends the functionality of Amazon Async SQS
+ * client.
+ *
+ *
+ * All service calls made using this client are asynchronous, and will return
+ * immediately with a {@link CompletableFuture} that completes when the operation
+ * completes or when an exception is thrown. Argument validation exceptions are thrown
+ * immediately, and not through the future.
+ *
+ *
+ *
+ * The Amazon SQS extended client enables sending and receiving large messages
+ * via Amazon S3. You can use this library to:
+ *
+ *
+ *
+ * - Specify whether messages are always stored in Amazon S3 or only when a
+ * message size exceeds 256 KB.
+ * - Send a message that references a single message object stored in an
+ * Amazon S3 bucket.
+ * - Get the corresponding message object from an Amazon S3 bucket.
+ * - Delete the corresponding message object from an Amazon S3 bucket.
+ *
+ */
+public class AmazonSQSExtendedAsyncClient extends AmazonSQSExtendedAsyncClientBase implements SqsAsyncClient {
+ static final String USER_AGENT_NAME = AmazonSQSExtendedAsyncClient.class.getSimpleName();
+ static final String USER_AGENT_VERSION = VersionInfo.SDK_VERSION;
+
+ private static final Log LOG = LogFactory.getLog(AmazonSQSExtendedAsyncClient.class);
+ private ExtendedAsyncClientConfiguration clientConfiguration;
+ private PayloadStoreAsync payloadStore;
+
+ /**
+ * Constructs a new Amazon SQS extended async client to invoke service methods on
+ * Amazon SQS with extended functionality using the specified Amazon SQS
+ * client object.
+ *
+ *
+ * All service calls made using this client are asynchronous, and will return
+ * immediately with a {@link CompletableFuture} that completes when the operation
+ * completes or when an exception is thrown. Argument validation exceptions are thrown
+ * immediately, and not through the future.
+ *
+ *
+ * @param sqsClient
+ * The Amazon SQS async client to use to connect to Amazon SQS.
+ */
+ public AmazonSQSExtendedAsyncClient(SqsAsyncClient sqsClient) {
+ this(sqsClient, new ExtendedAsyncClientConfiguration());
+ }
+
+ /**
+ * Constructs a new Amazon SQS extended client to invoke service methods on
+ * Amazon SQS with extended functionality using the specified Amazon SQS
+ * client object.
+ *
+ *
+ * All service calls made using this client are asynchronous, and will return
+ * immediately with a {@link CompletableFuture} that completes when the operation
+ * completes or when an exception is thrown. Argument validation exceptions are thrown
+ * immediately, and not through the future.
+ *
+ *
+ * @param sqsClient
+ * The Amazon SQS async client to use to connect to Amazon SQS.
+ * @param extendedClientConfig
+ * The extended client configuration options controlling the
+ * functionality of this client.
+ */
+ public AmazonSQSExtendedAsyncClient(SqsAsyncClient sqsClient,
+ ExtendedAsyncClientConfiguration extendedClientConfig) {
+ super(sqsClient);
+ this.clientConfiguration = new ExtendedAsyncClientConfiguration(extendedClientConfig);
+ S3AsyncDao s3Dao = new S3AsyncDao(clientConfiguration.getS3AsyncClient(),
+ clientConfiguration.getServerSideEncryptionStrategy(),
+ clientConfiguration.getObjectCannedACL());
+ this.payloadStore = new S3BackedPayloadStoreAsync(s3Dao, clientConfiguration.getS3BucketName());
+ }
+
+ /**
+ * {@inheritDoc}
+ */
+ @Override
+ public CompletableFuture sendMessage(SendMessageRequest sendMessageRequest) {
+ // TODO: Clone request since it's modified in this method and will cause issues if the client reuses request
+ // object.
+ if (sendMessageRequest == null) {
+ String errorMessage = "sendMessageRequest cannot be null.";
+ LOG.error(errorMessage);
+ throw SdkClientException.create(errorMessage);
+ }
+
+ SendMessageRequest.Builder sendMessageRequestBuilder = sendMessageRequest.toBuilder();
+ sendMessageRequest = appendUserAgent(sendMessageRequestBuilder).build();
+
+ if (!clientConfiguration.isPayloadSupportEnabled()) {
+ return super.sendMessage(sendMessageRequest);
+ }
+
+ if (StringUtils.isEmpty(sendMessageRequest.messageBody())) {
+ String errorMessage = "messageBody cannot be null or empty.";
+ LOG.error(errorMessage);
+ throw SdkClientException.create(errorMessage);
+ }
+
+ //Check message attributes for ExtendedClient related constraints
+ checkMessageAttributes(clientConfiguration.getPayloadSizeThreshold(), sendMessageRequest.messageAttributes());
+
+ if (clientConfiguration.isAlwaysThroughS3()
+ || isLarge(clientConfiguration.getPayloadSizeThreshold(), sendMessageRequest)) {
+ return storeMessageInS3(sendMessageRequest)
+ .thenCompose(modifiedRequest -> super.sendMessage(modifiedRequest));
+ }
+
+ return super.sendMessage(sendMessageRequest);
+ }
+
+ /**
+ * {@inheritDoc}
+ */
+ @Override
+ public CompletableFuture receiveMessage(ReceiveMessageRequest receiveMessageRequest) {
+ // TODO: Clone request since it's modified in this method and will cause issues if the client reuses request
+ // object.
+ if (receiveMessageRequest == null) {
+ String errorMessage = "receiveMessageRequest cannot be null.";
+ LOG.error(errorMessage);
+ throw SdkClientException.create(errorMessage);
+ }
+
+ ReceiveMessageRequest.Builder receiveMessageRequestBuilder = receiveMessageRequest.toBuilder();
+ appendUserAgent(receiveMessageRequestBuilder);
+
+ if (!clientConfiguration.isPayloadSupportEnabled()) {
+ return super.receiveMessage(receiveMessageRequestBuilder.build());
+ }
+
+ // Remove before adding to avoid any duplicates
+ List messageAttributeNames = new ArrayList<>(receiveMessageRequest.messageAttributeNames());
+ messageAttributeNames.removeAll(AmazonSQSExtendedClientUtil.RESERVED_ATTRIBUTE_NAMES);
+ messageAttributeNames.addAll(AmazonSQSExtendedClientUtil.RESERVED_ATTRIBUTE_NAMES);
+ receiveMessageRequestBuilder.messageAttributeNames(messageAttributeNames);
+ receiveMessageRequest = receiveMessageRequestBuilder.build();
+
+ return super.receiveMessage(receiveMessageRequest)
+ .thenCompose(receiveMessageResponse -> {
+ List messages = receiveMessageResponse.messages();
+
+ // Check for no messages. If so, no need to process further.
+ if (messages.isEmpty()) {
+ return CompletableFuture.completedFuture(messages);
+ }
+
+ List> modifiedMessageFutures = new ArrayList<>(messages.size());
+ for (Message message : messages) {
+ Message.Builder messageBuilder = message.toBuilder();
+
+ // For each received message check if they are stored in S3.
+ Optional largePayloadAttributeName = getReservedAttributeNameIfPresent(
+ message.messageAttributes());
+ if (!largePayloadAttributeName.isPresent()) {
+ // Not S3
+ modifiedMessageFutures.add(CompletableFuture.completedFuture(messageBuilder.build()));
+ } else {
+ // In S3
+ final String largeMessagePointer = message.body()
+ .replace("com.amazon.sqs.javamessaging.MessageS3Pointer",
+ "software.amazon.payloadoffloading.PayloadS3Pointer");
+
+ // Retrieve original payload
+ modifiedMessageFutures.add(payloadStore.getOriginalPayload(largeMessagePointer)
+ .thenApply(originalPayload -> {
+ // Set original payload
+ messageBuilder.body(originalPayload);
+
+ // Remove the additional attribute before returning the message
+ // to user.
+ Map messageAttributes = new HashMap<>(
+ message.messageAttributes());
+ messageAttributes.keySet().removeAll(AmazonSQSExtendedClientUtil.RESERVED_ATTRIBUTE_NAMES);
+ messageBuilder.messageAttributes(messageAttributes);
+
+ // Embed s3 object pointer in the receipt handle.
+ String modifiedReceiptHandle = embedS3PointerInReceiptHandle(
+ message.receiptHandle(),
+ largeMessagePointer);
+ messageBuilder.receiptHandle(modifiedReceiptHandle);
+
+ return messageBuilder.build();
+ }));
+ }
+ }
+
+ // Convert list of message futures to a future list of messages.
+ return CompletableFuture.allOf(
+ modifiedMessageFutures.toArray(new CompletableFuture[modifiedMessageFutures.size()]))
+ .thenApply(v -> modifiedMessageFutures.stream()
+ .map(CompletableFuture::join)
+ .collect(Collectors.toList()));
+ })
+ .thenApply(modifiedMessages -> {
+ // Build response with modified message list.
+ ReceiveMessageResponse.Builder receiveMessageResponseBuilder = ReceiveMessageResponse.builder();
+ receiveMessageResponseBuilder.messages(modifiedMessages);
+ return receiveMessageResponseBuilder.build();
+ });
+ }
+
+ /**
+ * {@inheritDoc}
+ */
+ @Override
+ public CompletableFuture deleteMessage(DeleteMessageRequest deleteMessageRequest) {
+ if (deleteMessageRequest == null) {
+ String errorMessage = "deleteMessageRequest cannot be null.";
+ LOG.error(errorMessage);
+ throw SdkClientException.create(errorMessage);
+ }
+
+ DeleteMessageRequest.Builder deleteMessageRequestBuilder = deleteMessageRequest.toBuilder();
+ appendUserAgent(deleteMessageRequestBuilder);
+
+ String receiptHandle = deleteMessageRequest.receiptHandle();
+ String origReceiptHandle = receiptHandle;
+ String messagePointer = null;
+
+ // Update original receipt handle if needed.
+ if (clientConfiguration.isPayloadSupportEnabled() && isS3ReceiptHandle(receiptHandle)) {
+ origReceiptHandle = getOrigReceiptHandle(receiptHandle);
+
+ // Delete pay load from S3 if needed
+ if (clientConfiguration.doesCleanupS3Payload()) {
+ messagePointer = getMessagePointerFromModifiedReceiptHandle(receiptHandle);
+ }
+ }
+
+ // The actual message to delete from SQS.
+ deleteMessageRequestBuilder.receiptHandle(origReceiptHandle);
+
+ // Check if message is in S3 or only in SQS.
+ if (messagePointer == null) {
+ // Delete only from SQS
+ return super.deleteMessage(deleteMessageRequestBuilder.build());
+ }
+
+ // Delete from SQS first, then S3.
+ final String messageToDeletePointer = messagePointer;
+ return super.deleteMessage(deleteMessageRequestBuilder.build())
+ .thenCompose(deleteMessageResponse ->
+ payloadStore.deleteOriginalPayload(messageToDeletePointer)
+ .thenApply(v -> deleteMessageResponse));
+ }
+
+ /**
+ * {@inheritDoc}
+ */
+ @Override
+ public CompletableFuture changeMessageVisibility(
+ ChangeMessageVisibilityRequest changeMessageVisibilityRequest) {
+
+ ChangeMessageVisibilityRequest.Builder changeMessageVisibilityRequestBuilder =
+ changeMessageVisibilityRequest.toBuilder();
+ if (isS3ReceiptHandle(changeMessageVisibilityRequest.receiptHandle())) {
+ changeMessageVisibilityRequestBuilder.receiptHandle(
+ getOrigReceiptHandle(changeMessageVisibilityRequest.receiptHandle()));
+ }
+ return amazonSqsToBeExtended.changeMessageVisibility(changeMessageVisibilityRequestBuilder.build());
+ }
+
+ /**
+ * {@inheritDoc}
+ */
+ @Override
+ public CompletableFuture sendMessageBatch(
+ SendMessageBatchRequest sendMessageBatchRequestIn) {
+
+ if (sendMessageBatchRequestIn == null) {
+ String errorMessage = "sendMessageBatchRequest cannot be null.";
+ LOG.error(errorMessage);
+ throw SdkClientException.create(errorMessage);
+ }
+
+ SendMessageBatchRequest.Builder sendMessageBatchRequestBuilder = sendMessageBatchRequestIn.toBuilder();
+ appendUserAgent(sendMessageBatchRequestBuilder);
+ SendMessageBatchRequest sendMessageBatchRequest = sendMessageBatchRequestBuilder.build();
+
+ if (!clientConfiguration.isPayloadSupportEnabled()) {
+ return super.sendMessageBatch(sendMessageBatchRequest);
+ }
+
+ List> batchEntryFutures = new ArrayList<>(
+ sendMessageBatchRequest.entries().size());
+ boolean hasS3Entries = false;
+ for (SendMessageBatchRequestEntry entry : sendMessageBatchRequest.entries()) {
+ //Check message attributes for ExtendedClient related constraints
+ checkMessageAttributes(clientConfiguration.getPayloadSizeThreshold(), entry.messageAttributes());
+
+ if (clientConfiguration.isAlwaysThroughS3()
+ || isLarge(clientConfiguration.getPayloadSizeThreshold(), entry)) {
+ batchEntryFutures.add(storeMessageInS3(entry));
+ hasS3Entries = true;
+ } else {
+ batchEntryFutures.add(CompletableFuture.completedFuture(entry));
+ }
+ }
+
+ if (!hasS3Entries) {
+ return super.sendMessageBatch(sendMessageBatchRequest);
+ }
+
+ // Convert list of entry futures to a future list of entries.
+ return CompletableFuture.allOf(
+ batchEntryFutures.toArray(new CompletableFuture[batchEntryFutures.size()]))
+ .thenApply(v -> batchEntryFutures.stream()
+ .map(CompletableFuture::join)
+ .collect(Collectors.toList()))
+ .thenCompose(batchEntries -> {
+ SendMessageBatchRequest modifiedBatchRequest =
+ sendMessageBatchRequest.toBuilder().entries(batchEntries).build();
+ return super.sendMessageBatch(modifiedBatchRequest);
+ });
+ }
+
+ /**
+ * {@inheritDoc}
+ */
+ @Override
+ public CompletableFuture deleteMessageBatch(
+ DeleteMessageBatchRequest deleteMessageBatchRequest) {
+
+ if (deleteMessageBatchRequest == null) {
+ String errorMessage = "deleteMessageBatchRequest cannot be null.";
+ LOG.error(errorMessage);
+ throw SdkClientException.create(errorMessage);
+ }
+
+ DeleteMessageBatchRequest.Builder deleteMessageBatchRequestBuilder = deleteMessageBatchRequest.toBuilder();
+ appendUserAgent(deleteMessageBatchRequestBuilder);
+
+ if (!clientConfiguration.isPayloadSupportEnabled()) {
+ return super.deleteMessageBatch(deleteMessageBatchRequest);
+ }
+
+ List entries = new ArrayList<>(deleteMessageBatchRequest.entries().size());
+ for (DeleteMessageBatchRequestEntry entry : deleteMessageBatchRequest.entries()) {
+ DeleteMessageBatchRequestEntry.Builder entryBuilder = entry.toBuilder();
+ String receiptHandle = entry.receiptHandle();
+ String origReceiptHandle = receiptHandle;
+
+ // Update original receipt handle if needed
+ if (isS3ReceiptHandle(receiptHandle)) {
+ origReceiptHandle = getOrigReceiptHandle(receiptHandle);
+ // Delete s3 payload if needed
+ if (clientConfiguration.doesCleanupS3Payload()) {
+ String messagePointer = getMessagePointerFromModifiedReceiptHandle(receiptHandle);
+ payloadStore.deleteOriginalPayload(messagePointer);
+ }
+ }
+
+ entryBuilder.receiptHandle(origReceiptHandle);
+ entries.add(entryBuilder.build());
+ }
+
+ deleteMessageBatchRequestBuilder.entries(entries);
+ return super.deleteMessageBatch(deleteMessageBatchRequestBuilder.build());
+ }
+
+ /**
+ * {@inheritDoc}
+ */
+ @Override
+ public CompletableFuture changeMessageVisibilityBatch(
+ ChangeMessageVisibilityBatchRequest changeMessageVisibilityBatchRequest) {
+ List entries = new ArrayList<>(
+ changeMessageVisibilityBatchRequest.entries().size());
+ for (ChangeMessageVisibilityBatchRequestEntry entry : changeMessageVisibilityBatchRequest.entries()) {
+ ChangeMessageVisibilityBatchRequestEntry.Builder entryBuilder = entry.toBuilder();
+ if (isS3ReceiptHandle(entry.receiptHandle())) {
+ entryBuilder.receiptHandle(getOrigReceiptHandle(entry.receiptHandle()));
+ }
+ entries.add(entryBuilder.build());
+ }
+
+ return amazonSqsToBeExtended.changeMessageVisibilityBatch(
+ changeMessageVisibilityBatchRequest.toBuilder().entries(entries).build());
+ }
+
+ /**
+ * {@inheritDoc}
+ */
+ @Override
+ public CompletableFuture purgeQueue(PurgeQueueRequest purgeQueueRequest) {
+ LOG.warn("Calling purgeQueue deletes SQS messages without deleting their payload from S3.");
+
+ if (purgeQueueRequest == null) {
+ String errorMessage = "purgeQueueRequest cannot be null.";
+ LOG.error(errorMessage);
+ throw SdkClientException.create(errorMessage);
+ }
+
+ PurgeQueueRequest.Builder purgeQueueRequestBuilder = purgeQueueRequest.toBuilder();
+ appendUserAgent(purgeQueueRequestBuilder);
+
+ return super.purgeQueue(purgeQueueRequestBuilder.build());
+ }
+
+ private CompletableFuture storeMessageInS3(SendMessageBatchRequestEntry batchEntry) {
+ // Read the content of the message from message body
+ String messageContentStr = batchEntry.messageBody();
+
+ Long messageContentSize = Util.getStringSizeInBytes(messageContentStr);
+
+ SendMessageBatchRequestEntry.Builder batchEntryBuilder = batchEntry.toBuilder();
+
+ batchEntryBuilder.messageAttributes(
+ updateMessageAttributePayloadSize(batchEntry.messageAttributes(), messageContentSize,
+ clientConfiguration.usesLegacyReservedAttributeName()));
+
+ // Store the message content in S3.
+ return storeOriginalPayload(messageContentStr)
+ .thenApply(largeMessagePointer -> {
+ batchEntryBuilder.messageBody(largeMessagePointer);
+ return batchEntryBuilder.build();
+ });
+ }
+
+ private CompletableFuture storeMessageInS3(SendMessageRequest sendMessageRequest) {
+ // Read the content of the message from message body
+ String messageContentStr = sendMessageRequest.messageBody();
+
+ Long messageContentSize = Util.getStringSizeInBytes(messageContentStr);
+
+ SendMessageRequest.Builder sendMessageRequestBuilder = sendMessageRequest.toBuilder();
+
+ sendMessageRequestBuilder.messageAttributes(
+ updateMessageAttributePayloadSize(sendMessageRequest.messageAttributes(), messageContentSize,
+ clientConfiguration.usesLegacyReservedAttributeName()));
+
+ // Store the message content in S3.
+ return payloadStore.storeOriginalPayload(messageContentStr)
+ .thenApply(largeMessagePointer -> {
+ sendMessageRequestBuilder.messageBody(largeMessagePointer);
+ return sendMessageRequestBuilder.build();
+ });
+ }
+
+ private CompletableFuture storeOriginalPayload(String messageContentStr) {
+ String s3KeyPrefix = clientConfiguration.getS3KeyPrefix();
+ if (StringUtils.isBlank(s3KeyPrefix)) {
+ return payloadStore.storeOriginalPayload(messageContentStr);
+ }
+ return payloadStore.storeOriginalPayload(messageContentStr, s3KeyPrefix + UUID.randomUUID());
+ }
+
+ private static T appendUserAgent(final T builder) {
+ return AmazonSQSExtendedClientUtil.appendUserAgent(builder, USER_AGENT_NAME, USER_AGENT_VERSION);
+ }
+}
diff --git a/src/main/java/com/amazon/sqs/javamessaging/AmazonSQSExtendedAsyncClientBase.java b/src/main/java/com/amazon/sqs/javamessaging/AmazonSQSExtendedAsyncClientBase.java
new file mode 100644
index 0000000..2547297
--- /dev/null
+++ b/src/main/java/com/amazon/sqs/javamessaging/AmazonSQSExtendedAsyncClientBase.java
@@ -0,0 +1,244 @@
+package com.amazon.sqs.javamessaging;
+
+import java.util.concurrent.CompletableFuture;
+import software.amazon.awssdk.services.sqs.SqsAsyncClient;
+import software.amazon.awssdk.services.sqs.model.AddPermissionRequest;
+import software.amazon.awssdk.services.sqs.model.AddPermissionResponse;
+import software.amazon.awssdk.services.sqs.model.ChangeMessageVisibilityBatchRequest;
+import software.amazon.awssdk.services.sqs.model.ChangeMessageVisibilityBatchResponse;
+import software.amazon.awssdk.services.sqs.model.ChangeMessageVisibilityRequest;
+import software.amazon.awssdk.services.sqs.model.ChangeMessageVisibilityResponse;
+import software.amazon.awssdk.services.sqs.model.CreateQueueRequest;
+import software.amazon.awssdk.services.sqs.model.CreateQueueResponse;
+import software.amazon.awssdk.services.sqs.model.DeleteMessageBatchRequest;
+import software.amazon.awssdk.services.sqs.model.DeleteMessageBatchResponse;
+import software.amazon.awssdk.services.sqs.model.DeleteMessageRequest;
+import software.amazon.awssdk.services.sqs.model.DeleteMessageResponse;
+import software.amazon.awssdk.services.sqs.model.DeleteQueueRequest;
+import software.amazon.awssdk.services.sqs.model.DeleteQueueResponse;
+import software.amazon.awssdk.services.sqs.model.GetQueueAttributesRequest;
+import software.amazon.awssdk.services.sqs.model.GetQueueAttributesResponse;
+import software.amazon.awssdk.services.sqs.model.GetQueueUrlRequest;
+import software.amazon.awssdk.services.sqs.model.GetQueueUrlResponse;
+import software.amazon.awssdk.services.sqs.model.ListDeadLetterSourceQueuesRequest;
+import software.amazon.awssdk.services.sqs.model.ListDeadLetterSourceQueuesResponse;
+import software.amazon.awssdk.services.sqs.model.ListQueueTagsRequest;
+import software.amazon.awssdk.services.sqs.model.ListQueueTagsResponse;
+import software.amazon.awssdk.services.sqs.model.ListQueuesRequest;
+import software.amazon.awssdk.services.sqs.model.ListQueuesResponse;
+import software.amazon.awssdk.services.sqs.model.PurgeQueueRequest;
+import software.amazon.awssdk.services.sqs.model.PurgeQueueResponse;
+import software.amazon.awssdk.services.sqs.model.ReceiveMessageRequest;
+import software.amazon.awssdk.services.sqs.model.ReceiveMessageResponse;
+import software.amazon.awssdk.services.sqs.model.RemovePermissionRequest;
+import software.amazon.awssdk.services.sqs.model.RemovePermissionResponse;
+import software.amazon.awssdk.services.sqs.model.SendMessageBatchRequest;
+import software.amazon.awssdk.services.sqs.model.SendMessageBatchResponse;
+import software.amazon.awssdk.services.sqs.model.SendMessageRequest;
+import software.amazon.awssdk.services.sqs.model.SendMessageResponse;
+import software.amazon.awssdk.services.sqs.model.SetQueueAttributesRequest;
+import software.amazon.awssdk.services.sqs.model.SetQueueAttributesResponse;
+import software.amazon.awssdk.services.sqs.model.TagQueueRequest;
+import software.amazon.awssdk.services.sqs.model.TagQueueResponse;
+import software.amazon.awssdk.services.sqs.model.UntagQueueRequest;
+import software.amazon.awssdk.services.sqs.model.UntagQueueResponse;
+
+abstract class AmazonSQSExtendedAsyncClientBase implements SqsAsyncClient {
+ SqsAsyncClient amazonSqsToBeExtended;
+
+ public AmazonSQSExtendedAsyncClientBase(SqsAsyncClient sqsClient) {
+ amazonSqsToBeExtended = sqsClient;
+ }
+
+ /**
+ * {@inheritDoc}
+ */
+ @Override
+ public CompletableFuture sendMessage(SendMessageRequest sendMessageRequest) {
+ return amazonSqsToBeExtended.sendMessage(sendMessageRequest);
+ }
+
+ /**
+ * {@inheritDoc}
+ */
+ @Override
+ public CompletableFuture receiveMessage(ReceiveMessageRequest receiveMessageRequest) {
+ return amazonSqsToBeExtended.receiveMessage(receiveMessageRequest);
+ }
+
+ /**
+ * {@inheritDoc}
+ */
+ @Override
+ public CompletableFuture deleteMessage(DeleteMessageRequest deleteMessageRequest) {
+ return amazonSqsToBeExtended.deleteMessage(deleteMessageRequest);
+ }
+
+ /**
+ * {@inheritDoc}
+ */
+ @Override
+ public CompletableFuture setQueueAttributes(
+ SetQueueAttributesRequest setQueueAttributesRequest) {
+ return amazonSqsToBeExtended.setQueueAttributes(setQueueAttributesRequest);
+ }
+
+ /**
+ * {@inheritDoc}
+ */
+ @Override
+ public CompletableFuture changeMessageVisibilityBatch(
+ ChangeMessageVisibilityBatchRequest changeMessageVisibilityBatchRequest) {
+ return amazonSqsToBeExtended.changeMessageVisibilityBatch(changeMessageVisibilityBatchRequest);
+ }
+
+ /**
+ * {@inheritDoc}
+ */
+ @Override
+ public CompletableFuture changeMessageVisibility(
+ ChangeMessageVisibilityRequest changeMessageVisibilityRequest) {
+ return amazonSqsToBeExtended.changeMessageVisibility(changeMessageVisibilityRequest);
+ }
+
+ /**
+ * {@inheritDoc}
+ */
+ @Override
+ public CompletableFuture getQueueUrl(GetQueueUrlRequest getQueueUrlRequest) {
+ return amazonSqsToBeExtended.getQueueUrl(getQueueUrlRequest);
+ }
+
+ /**
+ * {@inheritDoc}
+ */
+ @Override
+ public CompletableFuture removePermission(
+ RemovePermissionRequest removePermissionRequest) {
+ return amazonSqsToBeExtended.removePermission(removePermissionRequest);
+ }
+
+ /**
+ * {@inheritDoc}
+ */
+ @Override
+ public CompletableFuture getQueueAttributes(
+ GetQueueAttributesRequest getQueueAttributesRequest) {
+ return amazonSqsToBeExtended.getQueueAttributes(getQueueAttributesRequest);
+ }
+
+ /**
+ * {@inheritDoc}
+ */
+ @Override
+ public CompletableFuture sendMessageBatch(
+ SendMessageBatchRequest sendMessageBatchRequest) {
+ return amazonSqsToBeExtended.sendMessageBatch(sendMessageBatchRequest);
+ }
+
+ /**
+ * {@inheritDoc}
+ */
+ @Override
+ public CompletableFuture purgeQueue(PurgeQueueRequest purgeQueueRequest) {
+ return amazonSqsToBeExtended.purgeQueue(purgeQueueRequest);
+ }
+
+ /**
+ * {@inheritDoc}
+ */
+ @Override
+ public CompletableFuture listDeadLetterSourceQueues(
+ ListDeadLetterSourceQueuesRequest listDeadLetterSourceQueuesRequest) {
+ return amazonSqsToBeExtended.listDeadLetterSourceQueues(listDeadLetterSourceQueuesRequest);
+ }
+
+ /**
+ * {@inheritDoc}
+ */
+ @Override
+ public CompletableFuture deleteQueue(DeleteQueueRequest deleteQueueRequest) {
+ return amazonSqsToBeExtended.deleteQueue(deleteQueueRequest);
+ }
+
+ /**
+ * {@inheritDoc}
+ */
+ @Override
+ public CompletableFuture listQueues(ListQueuesRequest listQueuesRequest) {
+ return amazonSqsToBeExtended.listQueues(listQueuesRequest);
+ }
+
+ /**
+ * {@inheritDoc}
+ */
+ @Override
+ public CompletableFuture listQueues() {
+ return amazonSqsToBeExtended.listQueues();
+ }
+
+ /**
+ * {@inheritDoc}
+ */
+ @Override
+ public CompletableFuture deleteMessageBatch(
+ DeleteMessageBatchRequest deleteMessageBatchRequest) {
+ return amazonSqsToBeExtended.deleteMessageBatch(deleteMessageBatchRequest);
+ }
+
+ /**
+ * {@inheritDoc}
+ */
+ @Override
+ public CompletableFuture createQueue(CreateQueueRequest createQueueRequest) {
+ return amazonSqsToBeExtended.createQueue(createQueueRequest);
+ }
+
+ /**
+ * {@inheritDoc}
+ */
+ @Override
+ public CompletableFuture addPermission(AddPermissionRequest addPermissionRequest) {
+ return amazonSqsToBeExtended.addPermission(addPermissionRequest);
+ }
+
+ /**
+ * {@inheritDoc}
+ */
+ @Override
+ public CompletableFuture listQueueTags(final ListQueueTagsRequest listQueueTagsRequest) {
+ return amazonSqsToBeExtended.listQueueTags(listQueueTagsRequest);
+ }
+
+ /**
+ * {@inheritDoc}
+ */
+ @Override
+ public CompletableFuture tagQueue(final TagQueueRequest tagQueueRequest) {
+ return amazonSqsToBeExtended.tagQueue(tagQueueRequest);
+ }
+
+ /**
+ * {@inheritDoc}
+ */
+ @Override
+ public CompletableFuture untagQueue(final UntagQueueRequest untagQueueRequest) {
+ return amazonSqsToBeExtended.untagQueue(untagQueueRequest);
+ }
+
+ /**
+ * {@inheritDoc}
+ */
+ @Override
+ public String serviceName() {
+ return amazonSqsToBeExtended.serviceName();
+ }
+
+ /**
+ * {@inheritDoc}
+ */
+ @Override
+ public void close() {
+ amazonSqsToBeExtended.close();
+ }
+}
diff --git a/src/main/java/com/amazon/sqs/javamessaging/AmazonSQSExtendedClient.java b/src/main/java/com/amazon/sqs/javamessaging/AmazonSQSExtendedClient.java
index 3fb596f..68317d2 100644
--- a/src/main/java/com/amazon/sqs/javamessaging/AmazonSQSExtendedClient.java
+++ b/src/main/java/com/amazon/sqs/javamessaging/AmazonSQSExtendedClient.java
@@ -15,9 +15,16 @@
package com.amazon.sqs.javamessaging;
-import java.lang.UnsupportedOperationException;
+import static com.amazon.sqs.javamessaging.AmazonSQSExtendedClientUtil.checkMessageAttributes;
+import static com.amazon.sqs.javamessaging.AmazonSQSExtendedClientUtil.embedS3PointerInReceiptHandle;
+import static com.amazon.sqs.javamessaging.AmazonSQSExtendedClientUtil.getMessagePointerFromModifiedReceiptHandle;
+import static com.amazon.sqs.javamessaging.AmazonSQSExtendedClientUtil.getOrigReceiptHandle;
+import static com.amazon.sqs.javamessaging.AmazonSQSExtendedClientUtil.getReservedAttributeNameIfPresent;
+import static com.amazon.sqs.javamessaging.AmazonSQSExtendedClientUtil.isLarge;
+import static com.amazon.sqs.javamessaging.AmazonSQSExtendedClientUtil.isS3ReceiptHandle;
+import static com.amazon.sqs.javamessaging.AmazonSQSExtendedClientUtil.updateMessageAttributePayloadSize;
+
import java.util.ArrayList;
-import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
@@ -27,15 +34,11 @@
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import software.amazon.awssdk.awscore.AwsRequest;
-import software.amazon.awssdk.awscore.AwsRequestOverrideConfiguration;
import software.amazon.awssdk.awscore.exception.AwsServiceException;
-import software.amazon.awssdk.core.ApiName;
-import software.amazon.awssdk.core.SdkBytes;
import software.amazon.awssdk.core.exception.SdkClientException;
import software.amazon.awssdk.core.exception.SdkException;
import software.amazon.awssdk.core.util.VersionInfo;
import software.amazon.awssdk.services.sqs.SqsClient;
-
import software.amazon.awssdk.services.sqs.model.BatchEntryIdsNotDistinctException;
import software.amazon.awssdk.services.sqs.model.BatchRequestTooLongException;
import software.amazon.awssdk.services.sqs.model.ChangeMessageVisibilityBatchRequest;
@@ -71,7 +74,6 @@
import software.amazon.awssdk.services.sqs.model.SqsException;
import software.amazon.awssdk.services.sqs.model.TooManyEntriesInBatchRequestException;
import software.amazon.awssdk.utils.StringUtils;
-import software.amazon.payloadoffloading.PayloadS3Pointer;
import software.amazon.payloadoffloading.PayloadStore;
import software.amazon.payloadoffloading.S3BackedPayloadStore;
import software.amazon.payloadoffloading.S3Dao;
@@ -102,9 +104,6 @@ public class AmazonSQSExtendedClient extends AmazonSQSExtendedClientBase impleme
static final String USER_AGENT_VERSION = VersionInfo.SDK_VERSION;
private static final Log LOG = LogFactory.getLog(AmazonSQSExtendedClient.class);
- static final String LEGACY_RESERVED_ATTRIBUTE_NAME = "SQSLargePayloadSize";
- static final List RESERVED_ATTRIBUTE_NAMES = Arrays.asList(LEGACY_RESERVED_ATTRIBUTE_NAME,
- SQSExtendedClientConstants.RESERVED_ATTRIBUTE_NAME);
private ExtendedClientConfiguration clientConfiguration;
private PayloadStore payloadStore;
@@ -205,9 +204,10 @@ public SendMessageResponse sendMessage(SendMessageRequest sendMessageRequest) {
}
//Check message attributes for ExtendedClient related constraints
- checkMessageAttributes(sendMessageRequest.messageAttributes());
+ checkMessageAttributes(clientConfiguration.getPayloadSizeThreshold(), sendMessageRequest.messageAttributes());
- if (clientConfiguration.isAlwaysThroughS3() || isLarge(sendMessageRequest)) {
+ if (clientConfiguration.isAlwaysThroughS3()
+ || isLarge(clientConfiguration.getPayloadSizeThreshold(), sendMessageRequest)) {
sendMessageRequest = storeMessageInS3(sendMessageRequest);
}
return super.sendMessage(sendMessageRequest);
@@ -321,8 +321,8 @@ public ReceiveMessageResponse receiveMessage(ReceiveMessageRequest receiveMessag
}
//Remove before adding to avoid any duplicates
List messageAttributeNames = new ArrayList<>(receiveMessageRequest.messageAttributeNames());
- messageAttributeNames.removeAll(RESERVED_ATTRIBUTE_NAMES);
- messageAttributeNames.addAll(RESERVED_ATTRIBUTE_NAMES);
+ messageAttributeNames.removeAll(AmazonSQSExtendedClientUtil.RESERVED_ATTRIBUTE_NAMES);
+ messageAttributeNames.addAll(AmazonSQSExtendedClientUtil.RESERVED_ATTRIBUTE_NAMES);
receiveMessageRequestBuilder.messageAttributeNames(messageAttributeNames);
receiveMessageRequest = receiveMessageRequestBuilder.build();
@@ -345,7 +345,7 @@ public ReceiveMessageResponse receiveMessage(ReceiveMessageRequest receiveMessag
// remove the additional attribute before returning the message
// to user.
Map messageAttributes = new HashMap<>(message.messageAttributes());
- messageAttributes.keySet().removeAll(RESERVED_ATTRIBUTE_NAMES);
+ messageAttributes.keySet().removeAll(AmazonSQSExtendedClientUtil.RESERVED_ATTRIBUTE_NAMES);
messageBuilder.messageAttributes(messageAttributes);
// Embed s3 object pointer in the receipt handle.
@@ -621,9 +621,10 @@ public SendMessageBatchResponse sendMessageBatch(SendMessageBatchRequest sendMes
boolean hasS3Entries = false;
for (SendMessageBatchRequestEntry entry : sendMessageBatchRequest.entries()) {
//Check message attributes for ExtendedClient related constraints
- checkMessageAttributes(entry.messageAttributes());
+ checkMessageAttributes(clientConfiguration.getPayloadSizeThreshold(), entry.messageAttributes());
- if (clientConfiguration.isAlwaysThroughS3() || isLarge(entry)) {
+ if (clientConfiguration.isAlwaysThroughS3()
+ || isLarge(clientConfiguration.getPayloadSizeThreshold(), entry)) {
entry = storeMessageInS3(entry);
hasS3Entries = true;
}
@@ -838,121 +839,6 @@ public PurgeQueueResponse purgeQueue(PurgeQueueRequest purgeQueueRequest)
return super.purgeQueue(purgeQueueRequestBuilder.build());
}
- private void checkMessageAttributes(Map messageAttributes) {
- int msgAttributesSize = getMsgAttributesSize(messageAttributes);
- if (msgAttributesSize > clientConfiguration.getPayloadSizeThreshold()) {
- String errorMessage = "Total size of Message attributes is " + msgAttributesSize
- + " bytes which is larger than the threshold of " + clientConfiguration.getPayloadSizeThreshold()
- + " Bytes. Consider including the payload in the message body instead of message attributes.";
- LOG.error(errorMessage);
- throw SdkClientException.create(errorMessage);
- }
-
- int messageAttributesNum = messageAttributes.size();
- if (messageAttributesNum > SQSExtendedClientConstants.MAX_ALLOWED_ATTRIBUTES) {
- String errorMessage = "Number of message attributes [" + messageAttributesNum
- + "] exceeds the maximum allowed for large-payload messages ["
- + SQSExtendedClientConstants.MAX_ALLOWED_ATTRIBUTES + "].";
- LOG.error(errorMessage);
- throw SdkClientException.create(errorMessage);
- }
- Optional largePayloadAttributeName = getReservedAttributeNameIfPresent(messageAttributes);
-
- if (largePayloadAttributeName.isPresent()) {
- String errorMessage = "Message attribute name " + largePayloadAttributeName.get()
- + " is reserved for use by SQS extended client.";
- LOG.error(errorMessage);
- throw SdkClientException.create(errorMessage);
- }
- }
-
- /**
- * TODO: Wrap the message pointer as-is to the receiptHandle so that it can be generic
- * and does not use any LargeMessageStore implementation specific details.
- */
- private String embedS3PointerInReceiptHandle(String receiptHandle, String pointer) {
- PayloadS3Pointer s3Pointer = PayloadS3Pointer.fromJson(pointer);
- String s3MsgBucketName = s3Pointer.getS3BucketName();
- String s3MsgKey = s3Pointer.getS3Key();
-
- String modifiedReceiptHandle = SQSExtendedClientConstants.S3_BUCKET_NAME_MARKER + s3MsgBucketName
- + SQSExtendedClientConstants.S3_BUCKET_NAME_MARKER + SQSExtendedClientConstants.S3_KEY_MARKER
- + s3MsgKey + SQSExtendedClientConstants.S3_KEY_MARKER + receiptHandle;
- return modifiedReceiptHandle;
- }
-
- private String getOrigReceiptHandle(String receiptHandle) {
- int secondOccurence = receiptHandle.indexOf(SQSExtendedClientConstants.S3_KEY_MARKER,
- receiptHandle.indexOf(SQSExtendedClientConstants.S3_KEY_MARKER) + 1);
- return receiptHandle.substring(secondOccurence + SQSExtendedClientConstants.S3_KEY_MARKER.length());
- }
-
- private String getFromReceiptHandleByMarker(String receiptHandle, String marker) {
- int firstOccurence = receiptHandle.indexOf(marker);
- int secondOccurence = receiptHandle.indexOf(marker, firstOccurence + 1);
- return receiptHandle.substring(firstOccurence + marker.length(), secondOccurence);
- }
-
- private boolean isS3ReceiptHandle(String receiptHandle) {
- return receiptHandle.contains(SQSExtendedClientConstants.S3_BUCKET_NAME_MARKER)
- && receiptHandle.contains(SQSExtendedClientConstants.S3_KEY_MARKER);
- }
-
- private String getMessagePointerFromModifiedReceiptHandle(String receiptHandle) {
- String s3MsgBucketName = getFromReceiptHandleByMarker(receiptHandle, SQSExtendedClientConstants.S3_BUCKET_NAME_MARKER);
- String s3MsgKey = getFromReceiptHandleByMarker(receiptHandle, SQSExtendedClientConstants.S3_KEY_MARKER);
-
- PayloadS3Pointer payloadS3Pointer = new PayloadS3Pointer(s3MsgBucketName, s3MsgKey);
- return payloadS3Pointer.toJson();
- }
-
- private boolean isLarge(SendMessageRequest sendMessageRequest) {
- int msgAttributesSize = getMsgAttributesSize(sendMessageRequest.messageAttributes());
- long msgBodySize = Util.getStringSizeInBytes(sendMessageRequest.messageBody());
- long totalMsgSize = msgAttributesSize + msgBodySize;
- return (totalMsgSize > clientConfiguration.getPayloadSizeThreshold());
- }
-
- private boolean isLarge(SendMessageBatchRequestEntry batchEntry) {
- int msgAttributesSize = getMsgAttributesSize(batchEntry.messageAttributes());
- long msgBodySize = Util.getStringSizeInBytes(batchEntry.messageBody());
- long totalMsgSize = msgAttributesSize + msgBodySize;
- return (totalMsgSize > clientConfiguration.getPayloadSizeThreshold());
- }
-
- private Optional getReservedAttributeNameIfPresent(Map msgAttributes) {
- String reservedAttributeName = null;
- if (msgAttributes.containsKey(SQSExtendedClientConstants.RESERVED_ATTRIBUTE_NAME)) {
- reservedAttributeName = SQSExtendedClientConstants.RESERVED_ATTRIBUTE_NAME;
- } else if (msgAttributes.containsKey(LEGACY_RESERVED_ATTRIBUTE_NAME)) {
- reservedAttributeName = LEGACY_RESERVED_ATTRIBUTE_NAME;
- }
- return Optional.ofNullable(reservedAttributeName);
- }
-
- private int getMsgAttributesSize(Map msgAttributes) {
- int totalMsgAttributesSize = 0;
- for (Map.Entry entry : msgAttributes.entrySet()) {
- totalMsgAttributesSize += Util.getStringSizeInBytes(entry.getKey());
-
- MessageAttributeValue entryVal = entry.getValue();
- if (entryVal.dataType() != null) {
- totalMsgAttributesSize += Util.getStringSizeInBytes(entryVal.dataType());
- }
-
- String stringVal = entryVal.stringValue();
- if (stringVal != null) {
- totalMsgAttributesSize += Util.getStringSizeInBytes(entryVal.stringValue());
- }
-
- SdkBytes binaryVal = entryVal.binaryValue();
- if (binaryVal != null) {
- totalMsgAttributesSize += binaryVal.asByteArray().length;
- }
- }
- return totalMsgAttributesSize;
- }
-
private SendMessageBatchRequestEntry storeMessageInS3(SendMessageBatchRequestEntry batchEntry) {
// Read the content of the message from message body
@@ -963,7 +849,8 @@ private SendMessageBatchRequestEntry storeMessageInS3(SendMessageBatchRequestEnt
SendMessageBatchRequestEntry.Builder batchEntryBuilder = batchEntry.toBuilder();
batchEntryBuilder.messageAttributes(
- updateMessageAttributePayloadSize(batchEntry.messageAttributes(), messageContentSize));
+ updateMessageAttributePayloadSize(batchEntry.messageAttributes(), messageContentSize,
+ clientConfiguration.usesLegacyReservedAttributeName()));
// Store the message content in S3.
String largeMessagePointer = storeOriginalPayload(messageContentStr);
@@ -982,7 +869,8 @@ private SendMessageRequest storeMessageInS3(SendMessageRequest sendMessageReques
SendMessageRequest.Builder sendMessageRequestBuilder = sendMessageRequest.toBuilder();
sendMessageRequestBuilder.messageAttributes(
- updateMessageAttributePayloadSize(sendMessageRequest.messageAttributes(), messageContentSize));
+ updateMessageAttributePayloadSize(sendMessageRequest.messageAttributes(), messageContentSize,
+ clientConfiguration.usesLegacyReservedAttributeName()));
// Store the message content in S3.
String largeMessagePointer = storeOriginalPayload(messageContentStr);
@@ -999,32 +887,9 @@ private String storeOriginalPayload(String messageContentStr) {
return payloadStore.storeOriginalPayload(messageContentStr, s3KeyPrefix + UUID.randomUUID());
}
- private Map updateMessageAttributePayloadSize(
- Map messageAttributes, Long messageContentSize) {
- Map updatedMessageAttributes = new HashMap<>(messageAttributes);
-
- // Add a new message attribute as a flag
- MessageAttributeValue.Builder messageAttributeValueBuilder = MessageAttributeValue.builder();
- messageAttributeValueBuilder.dataType("Number");
- messageAttributeValueBuilder.stringValue(messageContentSize.toString());
- MessageAttributeValue messageAttributeValue = messageAttributeValueBuilder.build();
-
- if (!clientConfiguration.usesLegacyReservedAttributeName()) {
- updatedMessageAttributes.put(SQSExtendedClientConstants.RESERVED_ATTRIBUTE_NAME, messageAttributeValue);
- } else {
- updatedMessageAttributes.put(LEGACY_RESERVED_ATTRIBUTE_NAME, messageAttributeValue);
- }
- return updatedMessageAttributes;
- }
-
@SuppressWarnings("unchecked")
private static T appendUserAgent(final T builder) {
- return (T) builder
- .overrideConfiguration(
- AwsRequestOverrideConfiguration.builder()
- .addApiName(ApiName.builder().name(USER_AGENT_NAME)
- .version(USER_AGENT_VERSION).build())
- .build());
+ return AmazonSQSExtendedClientUtil.appendUserAgent(builder, USER_AGENT_NAME, USER_AGENT_VERSION);
}
@Override
diff --git a/src/main/java/com/amazon/sqs/javamessaging/AmazonSQSExtendedClientUtil.java b/src/main/java/com/amazon/sqs/javamessaging/AmazonSQSExtendedClientUtil.java
new file mode 100644
index 0000000..8bf1609
--- /dev/null
+++ b/src/main/java/com/amazon/sqs/javamessaging/AmazonSQSExtendedClientUtil.java
@@ -0,0 +1,200 @@
+package com.amazon.sqs.javamessaging;
+
+import java.util.Arrays;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.Optional;
+import org.apache.commons.logging.Log;
+import org.apache.commons.logging.LogFactory;
+import software.amazon.awssdk.awscore.AwsRequest;
+import software.amazon.awssdk.awscore.AwsRequestOverrideConfiguration;
+import software.amazon.awssdk.core.ApiName;
+import software.amazon.awssdk.core.SdkBytes;
+import software.amazon.awssdk.core.exception.SdkClientException;
+import software.amazon.awssdk.services.sqs.model.MessageAttributeValue;
+import software.amazon.awssdk.services.sqs.model.SendMessageBatchRequestEntry;
+import software.amazon.awssdk.services.sqs.model.SendMessageRequest;
+import software.amazon.awssdk.utils.StringUtils;
+import software.amazon.payloadoffloading.PayloadS3Pointer;
+import software.amazon.payloadoffloading.Util;
+
+public class AmazonSQSExtendedClientUtil {
+ private static final Log LOG = LogFactory.getLog(AmazonSQSExtendedClientUtil.class);
+
+ public static final String LEGACY_RESERVED_ATTRIBUTE_NAME = "SQSLargePayloadSize";
+ public static final List RESERVED_ATTRIBUTE_NAMES = Arrays.asList(LEGACY_RESERVED_ATTRIBUTE_NAME,
+ SQSExtendedClientConstants.RESERVED_ATTRIBUTE_NAME);
+
+ public static void checkMessageAttributes(int payloadSizeThreshold, Map messageAttributes) {
+ int msgAttributesSize = getMsgAttributesSize(messageAttributes);
+ if (msgAttributesSize > payloadSizeThreshold) {
+ String errorMessage = "Total size of Message attributes is " + msgAttributesSize
+ + " bytes which is larger than the threshold of " + payloadSizeThreshold
+ + " Bytes. Consider including the payload in the message body instead of message attributes.";
+ LOG.error(errorMessage);
+ throw SdkClientException.create(errorMessage);
+ }
+
+ int messageAttributesNum = messageAttributes.size();
+ if (messageAttributesNum > SQSExtendedClientConstants.MAX_ALLOWED_ATTRIBUTES) {
+ String errorMessage = "Number of message attributes [" + messageAttributesNum
+ + "] exceeds the maximum allowed for large-payload messages ["
+ + SQSExtendedClientConstants.MAX_ALLOWED_ATTRIBUTES + "].";
+ LOG.error(errorMessage);
+ throw SdkClientException.create(errorMessage);
+ }
+ Optional largePayloadAttributeName = getReservedAttributeNameIfPresent(messageAttributes);
+
+ if (largePayloadAttributeName.isPresent()) {
+ String errorMessage = "Message attribute name " + largePayloadAttributeName.get()
+ + " is reserved for use by SQS extended client.";
+ LOG.error(errorMessage);
+ throw SdkClientException.create(errorMessage);
+ }
+ }
+
+ public static Optional getReservedAttributeNameIfPresent(Map msgAttributes) {
+ String reservedAttributeName = null;
+ if (msgAttributes.containsKey(SQSExtendedClientConstants.RESERVED_ATTRIBUTE_NAME)) {
+ reservedAttributeName = SQSExtendedClientConstants.RESERVED_ATTRIBUTE_NAME;
+ } else if (msgAttributes.containsKey(LEGACY_RESERVED_ATTRIBUTE_NAME)) {
+ reservedAttributeName = LEGACY_RESERVED_ATTRIBUTE_NAME;
+ }
+ return Optional.ofNullable(reservedAttributeName);
+ }
+
+ public static String embedS3PointerInReceiptHandle(String receiptHandle, String pointer) {
+ PayloadS3Pointer s3Pointer = PayloadS3Pointer.fromJson(pointer);
+ String s3MsgBucketName = s3Pointer.getS3BucketName();
+ String s3MsgKey = s3Pointer.getS3Key();
+
+ return SQSExtendedClientConstants.S3_BUCKET_NAME_MARKER + s3MsgBucketName
+ + SQSExtendedClientConstants.S3_BUCKET_NAME_MARKER + SQSExtendedClientConstants.S3_KEY_MARKER
+ + s3MsgKey + SQSExtendedClientConstants.S3_KEY_MARKER + receiptHandle;
+ }
+
+ public static String getOrigReceiptHandle(String receiptHandle) {
+ int secondOccurence = receiptHandle.indexOf(SQSExtendedClientConstants.S3_KEY_MARKER,
+ receiptHandle.indexOf(SQSExtendedClientConstants.S3_KEY_MARKER) + 1);
+ return receiptHandle.substring(secondOccurence + SQSExtendedClientConstants.S3_KEY_MARKER.length());
+ }
+
+ public static boolean isS3ReceiptHandle(String receiptHandle) {
+ return receiptHandle.contains(SQSExtendedClientConstants.S3_BUCKET_NAME_MARKER)
+ && receiptHandle.contains(SQSExtendedClientConstants.S3_KEY_MARKER);
+ }
+
+ public static String getMessagePointerFromModifiedReceiptHandle(String receiptHandle) {
+ String s3MsgBucketName = getFromReceiptHandleByMarker(
+ receiptHandle, SQSExtendedClientConstants.S3_BUCKET_NAME_MARKER);
+ String s3MsgKey = getFromReceiptHandleByMarker(receiptHandle, SQSExtendedClientConstants.S3_KEY_MARKER);
+
+ PayloadS3Pointer payloadS3Pointer = new PayloadS3Pointer(s3MsgBucketName, s3MsgKey);
+ return payloadS3Pointer.toJson();
+ }
+
+ public static boolean isLarge(int payloadSizeThreshold, SendMessageRequest sendMessageRequest) {
+ int msgAttributesSize = getMsgAttributesSize(sendMessageRequest.messageAttributes());
+ long msgBodySize = Util.getStringSizeInBytes(sendMessageRequest.messageBody());
+ long totalMsgSize = msgAttributesSize + msgBodySize;
+ return (totalMsgSize > payloadSizeThreshold);
+ }
+
+ public static boolean isLarge(int payloadSizeThreshold, SendMessageBatchRequestEntry batchEntry) {
+ int msgAttributesSize = getMsgAttributesSize(batchEntry.messageAttributes());
+ long msgBodySize = Util.getStringSizeInBytes(batchEntry.messageBody());
+ long totalMsgSize = msgAttributesSize + msgBodySize;
+ return (totalMsgSize > payloadSizeThreshold);
+ }
+
+ public static Map updateMessageAttributePayloadSize(
+ Map messageAttributes, Long messageContentSize,
+ boolean usesLegacyReservedAttributeName) {
+ Map updatedMessageAttributes = new HashMap<>(messageAttributes);
+
+ // Add a new message attribute as a flag
+ MessageAttributeValue.Builder messageAttributeValueBuilder = MessageAttributeValue.builder();
+ messageAttributeValueBuilder.dataType("Number");
+ messageAttributeValueBuilder.stringValue(messageContentSize.toString());
+ MessageAttributeValue messageAttributeValue = messageAttributeValueBuilder.build();
+
+ if (!usesLegacyReservedAttributeName) {
+ updatedMessageAttributes.put(SQSExtendedClientConstants.RESERVED_ATTRIBUTE_NAME, messageAttributeValue);
+ } else {
+ updatedMessageAttributes.put(LEGACY_RESERVED_ATTRIBUTE_NAME, messageAttributeValue);
+ }
+ return updatedMessageAttributes;
+ }
+
+ @SuppressWarnings("unchecked")
+ public static T appendUserAgent(
+ final T builder, String userAgentName, String userAgentVersion) {
+ return (T) builder
+ .overrideConfiguration(
+ AwsRequestOverrideConfiguration.builder()
+ .addApiName(ApiName.builder().name(userAgentName)
+ .version(userAgentVersion).build())
+ .build());
+ }
+
+ private static String getFromReceiptHandleByMarker(String receiptHandle, String marker) {
+ int firstOccurence = receiptHandle.indexOf(marker);
+ int secondOccurence = receiptHandle.indexOf(marker, firstOccurence + 1);
+ return receiptHandle.substring(firstOccurence + marker.length(), secondOccurence);
+ }
+
+ private static int getMsgAttributesSize(Map msgAttributes) {
+ int totalMsgAttributesSize = 0;
+ for (Map.Entry entry : msgAttributes.entrySet()) {
+ totalMsgAttributesSize += Util.getStringSizeInBytes(entry.getKey());
+
+ MessageAttributeValue entryVal = entry.getValue();
+ if (entryVal.dataType() != null) {
+ totalMsgAttributesSize += Util.getStringSizeInBytes(entryVal.dataType());
+ }
+
+ String stringVal = entryVal.stringValue();
+ if (stringVal != null) {
+ totalMsgAttributesSize += Util.getStringSizeInBytes(entryVal.stringValue());
+ }
+
+ SdkBytes binaryVal = entryVal.binaryValue();
+ if (binaryVal != null) {
+ totalMsgAttributesSize += binaryVal.asByteArray().length;
+ }
+ }
+ return totalMsgAttributesSize;
+ }
+
+ public static String trimAndValidateS3KeyPrefix(String s3KeyPrefix) {
+ String trimmedPrefix = StringUtils.trimToEmpty(s3KeyPrefix);
+
+ if (trimmedPrefix.length() > SQSExtendedClientConstants.MAX_S3_KEY_PREFIX_LENGTH) {
+ String errorMessage = "The S3 key prefix length must not be greater than "
+ + SQSExtendedClientConstants.MAX_S3_KEY_PREFIX_LENGTH;
+ LOG.error(errorMessage);
+ throw SdkClientException.create(errorMessage);
+ }
+
+ if (trimmedPrefix.startsWith(".") || trimmedPrefix.startsWith("/")) {
+ String errorMessage = "The S3 key prefix must not starts with '.' or '/'";
+ LOG.error(errorMessage);
+ throw SdkClientException.create(errorMessage);
+ }
+
+ if (trimmedPrefix.contains("..")) {
+ String errorMessage = "The S3 key prefix must not contains the string '..'";
+ LOG.error(errorMessage);
+ throw SdkClientException.create(errorMessage);
+ }
+
+ if (SQSExtendedClientConstants.INVALID_S3_PREFIX_KEY_CHARACTERS_PATTERN.matcher(trimmedPrefix).find()) {
+ String errorMessage = "The S3 key prefix contain invalid characters. The allowed characters are: letters, digits, '/', '_', '-', and '.'";
+ LOG.error(errorMessage);
+ throw SdkClientException.create(errorMessage);
+ }
+
+ return trimmedPrefix;
+ }
+}
diff --git a/src/main/java/com/amazon/sqs/javamessaging/ExtendedAsyncClientConfiguration.java b/src/main/java/com/amazon/sqs/javamessaging/ExtendedAsyncClientConfiguration.java
new file mode 100644
index 0000000..e96fff2
--- /dev/null
+++ b/src/main/java/com/amazon/sqs/javamessaging/ExtendedAsyncClientConfiguration.java
@@ -0,0 +1,215 @@
+package com.amazon.sqs.javamessaging;
+
+import software.amazon.awssdk.annotations.NotThreadSafe;
+import software.amazon.awssdk.services.s3.S3AsyncClient;
+import software.amazon.awssdk.services.s3.model.ObjectCannedACL;
+import software.amazon.awssdk.utils.StringUtils;
+import software.amazon.payloadoffloading.PayloadStorageAsyncConfiguration;
+import software.amazon.payloadoffloading.ServerSideEncryptionStrategy;
+
+/**
+ * Amazon SQS extended client configuration options such as async Amazon S3 client,
+ * bucket name, and message size threshold for large-payload messages.
+ */
+@NotThreadSafe
+public class ExtendedAsyncClientConfiguration extends PayloadStorageAsyncConfiguration {
+
+ private boolean cleanupS3Payload = true;
+ private boolean useLegacyReservedAttributeName = true;
+ private boolean ignorePayloadNotFound = false;
+ private String s3KeyPrefix = "";
+
+ public ExtendedAsyncClientConfiguration() {
+ this.setPayloadSizeThreshold(SQSExtendedClientConstants.DEFAULT_MESSAGE_SIZE_THRESHOLD);
+ }
+
+ public ExtendedAsyncClientConfiguration(ExtendedAsyncClientConfiguration other) {
+ super(other);
+ this.cleanupS3Payload = other.doesCleanupS3Payload();
+ this.useLegacyReservedAttributeName = other.usesLegacyReservedAttributeName();
+ this.ignorePayloadNotFound = other.ignoresPayloadNotFound();
+ this.s3KeyPrefix = other.s3KeyPrefix;
+ }
+
+ /**
+ * Enables asynchronous support for payload messages.
+ * @param s3Async
+ * Amazon S3 client which is going to be used for storing
+ * payload messages.
+ * @param s3BucketName
+ * Name of the bucket which is going to be used for storing
+ * payload messages. The bucket must be already created and
+ * configured in s3.
+ * @param cleanupS3Payload
+ * If set to true, would handle deleting the S3 object as part
+ * of deleting the message from SQS queue. Otherwise, would not
+ * attempt to delete the object from S3. If opted to not delete S3
+ * objects its the responsibility to the message producer to handle
+ * the clean up appropriately.
+ */
+ public void setPayloadSupportEnabled(S3AsyncClient s3Async, String s3BucketName, boolean cleanupS3Payload) {
+ setPayloadSupportEnabled(s3Async, s3BucketName);
+ this.cleanupS3Payload = cleanupS3Payload;
+ }
+
+ /**
+ * Enables asynchronous support for payload messages.
+ * @param s3Async
+ * Amazon S3 client which is going to be used for storing
+ * payload messages.
+ * @param s3BucketName
+ * Name of the bucket which is going to be used for storing
+ * payload messages. The bucket must be already created and
+ * configured in s3.
+ * @param cleanupS3Payload
+ * If set to true, would handle deleting the S3 object as part
+ * of deleting the message from SQS queue. Otherwise, would not
+ * attempt to delete the object from S3. If opted to not delete S3
+ * objects its the responsibility to the message producer to handle
+ * the clean up appropriately.
+ */
+ public ExtendedAsyncClientConfiguration withPayloadSupportEnabled(
+ S3AsyncClient s3Async, String s3BucketName, boolean cleanupS3Payload) {
+ setPayloadSupportEnabled(s3Async, s3BucketName, cleanupS3Payload);
+ return this;
+ }
+
+ @Override
+ public ExtendedAsyncClientConfiguration withPayloadSupportEnabled(S3AsyncClient s3Async, String s3BucketName) {
+ this.setPayloadSupportEnabled(s3Async, s3BucketName);
+ return this;
+ }
+
+ /**
+ * Disables the utilization legacy payload attribute name when sending messages.
+ */
+ public void setLegacyReservedAttributeNameDisabled() {
+ this.useLegacyReservedAttributeName = false;
+ }
+
+ /**
+ * Disables the utilization legacy payload attribute name when sending messages.
+ */
+ public ExtendedAsyncClientConfiguration withLegacyReservedAttributeNameDisabled() {
+ setLegacyReservedAttributeNameDisabled();
+ return this;
+ }
+
+ /**
+ * Sets whether or not messages should be removed from Amazon SQS
+ * when payloads are not found in Amazon S3.
+ *
+ * @param ignorePayloadNotFound
+ * Whether or not messages should be removed from Amazon SQS
+ * when payloads are not found in Amazon S3. Default: false
+ */
+ public void setIgnorePayloadNotFound(boolean ignorePayloadNotFound) {
+ this.ignorePayloadNotFound = ignorePayloadNotFound;
+ }
+
+ /**
+ * Sets whether or not messages should be removed from Amazon SQS
+ * when payloads are not found in Amazon S3.
+ *
+ * @param ignorePayloadNotFound
+ * Whether or not messages should be removed from Amazon SQS
+ * when payloads are not found in Amazon S3. Default: false
+ * @return the updated ExtendedAsyncClientConfiguration object.
+ */
+ public ExtendedAsyncClientConfiguration withIgnorePayloadNotFound(boolean ignorePayloadNotFound) {
+ setIgnorePayloadNotFound(ignorePayloadNotFound);
+ return this;
+ }
+ /**
+ * Sets a string that will be used as prefix of the S3 Key.
+ *
+ * @param s3KeyPrefix
+ * A S3 key prefix value
+ */
+ public void setS3KeyPrefix(String s3KeyPrefix) {
+ this.s3KeyPrefix = AmazonSQSExtendedClientUtil.trimAndValidateS3KeyPrefix(s3KeyPrefix);
+ }
+
+ /**
+ * Sets a string that will be used as prefix of the S3 Key.
+ *
+ * @param s3KeyPrefix
+ * A S3 key prefix value
+ *
+ * @return the updated ExtendedClientConfiguration object.
+ */
+ public ExtendedAsyncClientConfiguration withS3KeyPrefix(String s3KeyPrefix) {
+ setS3KeyPrefix(s3KeyPrefix);
+ return this;
+ }
+
+ /**
+ * Gets the S3 key prefix
+ * @return the prefix value which is being used for compose the S3 key.
+ */
+ public String getS3KeyPrefix() {
+ return this.s3KeyPrefix;
+ }
+
+ /**
+ * Checks whether or not clean up large objects in S3 is enabled.
+ *
+ * @return True if clean up is enabled when deleting the concerning SQS message.
+ * Default: true
+ */
+ public boolean doesCleanupS3Payload() {
+ return cleanupS3Payload;
+ }
+
+ /**
+ * Checks whether or not the configuration uses the legacy reserved attribute name.
+ *
+ * @return True if legacy reserved attribute name is used.
+ * Default: true
+ */
+
+ public boolean usesLegacyReservedAttributeName() {
+ return useLegacyReservedAttributeName;
+ }
+
+ /**
+ * Checks whether or not messages should be removed from Amazon SQS
+ * when payloads are not found in Amazon S3.
+ *
+ * @return True if messages should be removed from Amazon SQS
+ * when payloads are not found in Amazon S3. Default: false
+ */
+ public boolean ignoresPayloadNotFound() {
+ return ignorePayloadNotFound;
+ }
+
+ @Override
+ public ExtendedAsyncClientConfiguration withAlwaysThroughS3(boolean alwaysThroughS3) {
+ setAlwaysThroughS3(alwaysThroughS3);
+ return this;
+ }
+
+ @Override
+ public ExtendedAsyncClientConfiguration withObjectCannedACL(ObjectCannedACL objectCannedACL) {
+ this.setObjectCannedACL(objectCannedACL);
+ return this;
+ }
+
+ @Override
+ public ExtendedAsyncClientConfiguration withPayloadSizeThreshold(int payloadSizeThreshold) {
+ this.setPayloadSizeThreshold(payloadSizeThreshold);
+ return this;
+ }
+
+ @Override
+ public ExtendedAsyncClientConfiguration withPayloadSupportDisabled() {
+ this.setPayloadSupportDisabled();
+ return this;
+ }
+
+ @Override
+ public ExtendedAsyncClientConfiguration withServerSideEncryption(ServerSideEncryptionStrategy serverSideEncryption) {
+ this.setServerSideEncryptionStrategy(serverSideEncryption);
+ return this;
+ }
+}
\ No newline at end of file
diff --git a/src/main/java/com/amazon/sqs/javamessaging/ExtendedClientConfiguration.java b/src/main/java/com/amazon/sqs/javamessaging/ExtendedClientConfiguration.java
index 19e2d39..75a30f8 100644
--- a/src/main/java/com/amazon/sqs/javamessaging/ExtendedClientConfiguration.java
+++ b/src/main/java/com/amazon/sqs/javamessaging/ExtendedClientConfiguration.java
@@ -18,15 +18,11 @@
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import software.amazon.awssdk.annotations.NotThreadSafe;
-import software.amazon.awssdk.core.exception.SdkClientException;
import software.amazon.awssdk.services.s3.S3Client;
import software.amazon.awssdk.services.s3.model.ObjectCannedACL;
-import software.amazon.awssdk.utils.StringUtils;
import software.amazon.payloadoffloading.PayloadStorageConfiguration;
import software.amazon.payloadoffloading.ServerSideEncryptionStrategy;
-import java.util.regex.Pattern;
-
/**
* Amazon SQS extended client configuration options such as Amazon S3 client,
@@ -36,11 +32,6 @@
public class ExtendedClientConfiguration extends PayloadStorageConfiguration {
private static final Log LOG = LogFactory.getLog(ExtendedClientConfiguration.class);
- private static final int UUID_LENGTH = 36;
- private static final int MAX_S3_KEY_LENGTH = 1024;
- private static final int MAX_S3_KEY_PREFIX_LENGTH = MAX_S3_KEY_LENGTH - UUID_LENGTH;
- private static final Pattern INVALID_S3_PREFIX_KEY_CHARACTERS_PATTERN = Pattern.compile("[^a-zA-Z0-9./_-]");
-
private boolean cleanupS3Payload = true;
private boolean useLegacyReservedAttributeName = true;
private boolean ignorePayloadNotFound = false;
@@ -149,33 +140,7 @@ public ExtendedClientConfiguration withIgnorePayloadNotFound(boolean ignorePaylo
* A S3 key prefix value
*/
public void setS3KeyPrefix(String s3KeyPrefix) {
- String trimmedPrefix = StringUtils.trimToEmpty(s3KeyPrefix);
-
- if (trimmedPrefix.length() > MAX_S3_KEY_PREFIX_LENGTH) {
- String errorMessage = "The S3 key prefix length must not be greater than " + MAX_S3_KEY_PREFIX_LENGTH;
- LOG.error(errorMessage);
- throw SdkClientException.create(errorMessage);
- }
-
- if (trimmedPrefix.startsWith(".") || trimmedPrefix.startsWith("/")) {
- String errorMessage = "The S3 key prefix must not starts with '.' or '/'";
- LOG.error(errorMessage);
- throw SdkClientException.create(errorMessage);
- }
-
- if (trimmedPrefix.contains("..")) {
- String errorMessage = "The S3 key prefix must not contains the string '..'";
- LOG.error(errorMessage);
- throw SdkClientException.create(errorMessage);
- }
-
- if (INVALID_S3_PREFIX_KEY_CHARACTERS_PATTERN.matcher(trimmedPrefix).find()) {
- String errorMessage = "The S3 key prefix contain invalid characters. The allowed characters are: letters, digits, '/', '_', '-', and '.'";
- LOG.error(errorMessage);
- throw SdkClientException.create(errorMessage);
- }
-
- this.s3KeyPrefix = trimmedPrefix;
+ this.s3KeyPrefix = AmazonSQSExtendedClientUtil.trimAndValidateS3KeyPrefix(s3KeyPrefix);
}
/**
diff --git a/src/main/java/com/amazon/sqs/javamessaging/SQSExtendedClientConstants.java b/src/main/java/com/amazon/sqs/javamessaging/SQSExtendedClientConstants.java
index 97f1f22..ad5ad34 100644
--- a/src/main/java/com/amazon/sqs/javamessaging/SQSExtendedClientConstants.java
+++ b/src/main/java/com/amazon/sqs/javamessaging/SQSExtendedClientConstants.java
@@ -16,6 +16,8 @@
package com.amazon.sqs.javamessaging;
+import java.util.regex.Pattern;
+
public class SQSExtendedClientConstants {
// This constant is shared with SNSExtendedClient
// SNS team should be notified of any changes made to this
@@ -31,4 +33,12 @@ public class SQSExtendedClientConstants {
public static final String S3_BUCKET_NAME_MARKER = "-..s3BucketName..-";
public static final String S3_KEY_MARKER = "-..s3Key..-";
+
+ public static final int UUID_LENGTH = 36;
+
+ public static final int MAX_S3_KEY_LENGTH = 1024;
+
+ public static final int MAX_S3_KEY_PREFIX_LENGTH = MAX_S3_KEY_LENGTH - UUID_LENGTH;
+
+ public static final Pattern INVALID_S3_PREFIX_KEY_CHARACTERS_PATTERN = Pattern.compile("[^a-zA-Z0-9./_-]");
}
diff --git a/src/test/java/com/amazon/sqs/javamessaging/AmazonSQSExtendedAsyncClientTest.java b/src/test/java/com/amazon/sqs/javamessaging/AmazonSQSExtendedAsyncClientTest.java
new file mode 100644
index 0000000..eb7de08
--- /dev/null
+++ b/src/test/java/com/amazon/sqs/javamessaging/AmazonSQSExtendedAsyncClientTest.java
@@ -0,0 +1,669 @@
+package com.amazon.sqs.javamessaging;
+
+import static com.amazon.sqs.javamessaging.AmazonSQSExtendedAsyncClient.USER_AGENT_NAME;
+import static com.amazon.sqs.javamessaging.AmazonSQSExtendedAsyncClient.USER_AGENT_VERSION;
+import static org.junit.jupiter.api.Assertions.assertEquals;
+import static org.junit.jupiter.api.Assertions.assertFalse;
+import static org.junit.jupiter.api.Assertions.assertNotEquals;
+import static org.junit.jupiter.api.Assertions.assertNull;
+import static org.junit.jupiter.api.Assertions.assertTrue;
+import static org.mockito.ArgumentMatchers.any;
+import static org.mockito.ArgumentMatchers.eq;
+import static org.mockito.ArgumentMatchers.isA;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.never;
+import static org.mockito.Mockito.spy;
+import static org.mockito.Mockito.times;
+import static org.mockito.Mockito.verify;
+import static org.mockito.Mockito.verifyNoInteractions;
+import static org.mockito.Mockito.when;
+
+import java.nio.charset.StandardCharsets;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.List;
+import java.util.Map;
+import java.util.UUID;
+import java.util.concurrent.CompletableFuture;
+import java.util.stream.Collectors;
+import java.util.stream.IntStream;
+import org.junit.jupiter.api.BeforeEach;
+import org.junit.jupiter.api.Test;
+import org.mockito.ArgumentCaptor;
+import software.amazon.awssdk.awscore.AwsRequestOverrideConfiguration;
+import software.amazon.awssdk.core.ApiName;
+import software.amazon.awssdk.core.ResponseBytes;
+import software.amazon.awssdk.core.async.AsyncRequestBody;
+import software.amazon.awssdk.core.async.AsyncResponseTransformer;
+import software.amazon.awssdk.services.s3.S3AsyncClient;
+import software.amazon.awssdk.services.s3.model.DeleteObjectRequest;
+import software.amazon.awssdk.services.s3.model.DeleteObjectResponse;
+import software.amazon.awssdk.services.s3.model.GetObjectRequest;
+import software.amazon.awssdk.services.s3.model.GetObjectResponse;
+import software.amazon.awssdk.services.s3.model.ObjectCannedACL;
+import software.amazon.awssdk.services.s3.model.PutObjectRequest;
+import software.amazon.awssdk.services.sqs.SqsAsyncClient;
+import software.amazon.awssdk.services.sqs.model.DeleteMessageBatchRequest;
+import software.amazon.awssdk.services.sqs.model.DeleteMessageBatchRequestEntry;
+import software.amazon.awssdk.services.sqs.model.DeleteMessageBatchResponse;
+import software.amazon.awssdk.services.sqs.model.DeleteMessageRequest;
+import software.amazon.awssdk.services.sqs.model.DeleteMessageResponse;
+import software.amazon.awssdk.services.sqs.model.Message;
+import software.amazon.awssdk.services.sqs.model.MessageAttributeValue;
+import software.amazon.awssdk.services.sqs.model.ReceiveMessageRequest;
+import software.amazon.awssdk.services.sqs.model.ReceiveMessageResponse;
+import software.amazon.awssdk.services.sqs.model.SendMessageBatchRequest;
+import software.amazon.awssdk.services.sqs.model.SendMessageBatchRequestEntry;
+import software.amazon.awssdk.services.sqs.model.SendMessageBatchResponse;
+import software.amazon.awssdk.services.sqs.model.SendMessageRequest;
+import software.amazon.awssdk.services.sqs.model.SendMessageResponse;
+import software.amazon.awssdk.utils.ImmutableMap;
+import software.amazon.payloadoffloading.PayloadS3Pointer;
+import software.amazon.payloadoffloading.ServerSideEncryptionFactory;
+import software.amazon.payloadoffloading.ServerSideEncryptionStrategy;
+
+public class AmazonSQSExtendedAsyncClientTest {
+
+ private SqsAsyncClient extendedSqsWithDefaultConfig;
+ private SqsAsyncClient extendedSqsWithCustomKMS;
+ private SqsAsyncClient extendedSqsWithDefaultKMS;
+ private SqsAsyncClient extendedSqsWithGenericReservedAttributeName;
+ private SqsAsyncClient mockSqsBackend;
+ private S3AsyncClient mockS3;
+ private static final String S3_BUCKET_NAME = "test-bucket-name";
+ private static final String SQS_QUEUE_URL = "test-queue-url";
+ private static final String S3_SERVER_SIDE_ENCRYPTION_KMS_KEY_ID = "test-customer-managed-kms-key-id";
+
+ private static final int LESS_THAN_SQS_SIZE_LIMIT = 3;
+ private static final int SQS_SIZE_LIMIT = 262144;
+ private static final int MORE_THAN_SQS_SIZE_LIMIT = SQS_SIZE_LIMIT + 1;
+ private static final ServerSideEncryptionStrategy SERVER_SIDE_ENCRYPTION_CUSTOM_STRATEGY = ServerSideEncryptionFactory.customerKey(S3_SERVER_SIDE_ENCRYPTION_KMS_KEY_ID);
+ private static final ServerSideEncryptionStrategy SERVER_SIDE_ENCRYPTION_DEFAULT_STRATEGY = ServerSideEncryptionFactory.awsManagedCmk();
+
+ // should be > 1 and << SQS_SIZE_LIMIT
+ private static final int ARBITRARY_SMALLER_THRESHOLD = 500;
+
+ @BeforeEach
+ public void setupClients() {
+ mockS3 = mock(S3AsyncClient.class);
+ mockSqsBackend = mock(SqsAsyncClient.class);
+ when(mockS3.putObject(isA(PutObjectRequest.class), isA(AsyncRequestBody.class))).thenReturn(
+ CompletableFuture.completedFuture(null));
+ when(mockS3.deleteObject(isA(DeleteObjectRequest.class))).thenReturn(
+ CompletableFuture.completedFuture(DeleteObjectResponse.builder().build()));
+ when(mockSqsBackend.sendMessage(isA(SendMessageRequest.class))).thenReturn(
+ CompletableFuture.completedFuture(SendMessageResponse.builder().build()));
+ when(mockSqsBackend.sendMessageBatch(isA(SendMessageBatchRequest.class))).thenReturn(
+ CompletableFuture.completedFuture(SendMessageBatchResponse.builder().build()));
+ when(mockSqsBackend.deleteMessage(isA(DeleteMessageRequest.class))).thenReturn(
+ CompletableFuture.completedFuture(DeleteMessageResponse.builder().build()));
+ when(mockSqsBackend.deleteMessageBatch(isA(DeleteMessageBatchRequest.class))).thenReturn(
+ CompletableFuture.completedFuture(DeleteMessageBatchResponse.builder().build()));
+
+ ExtendedAsyncClientConfiguration extendedClientConfiguration = new ExtendedAsyncClientConfiguration()
+ .withPayloadSupportEnabled(mockS3, S3_BUCKET_NAME);
+
+ ExtendedAsyncClientConfiguration extendedClientConfigurationWithCustomKMS = new ExtendedAsyncClientConfiguration()
+ .withPayloadSupportEnabled(mockS3, S3_BUCKET_NAME)
+ .withServerSideEncryption(SERVER_SIDE_ENCRYPTION_CUSTOM_STRATEGY);
+
+ ExtendedAsyncClientConfiguration extendedClientConfigurationWithDefaultKMS = new ExtendedAsyncClientConfiguration()
+ .withPayloadSupportEnabled(mockS3, S3_BUCKET_NAME)
+ .withServerSideEncryption(SERVER_SIDE_ENCRYPTION_DEFAULT_STRATEGY);
+
+ ExtendedAsyncClientConfiguration extendedClientConfigurationWithGenericReservedAttributeName = new ExtendedAsyncClientConfiguration()
+ .withPayloadSupportEnabled(mockS3, S3_BUCKET_NAME).withLegacyReservedAttributeNameDisabled();
+
+ extendedSqsWithDefaultConfig = spy(new AmazonSQSExtendedAsyncClient(mockSqsBackend, extendedClientConfiguration));
+ extendedSqsWithCustomKMS = spy(new AmazonSQSExtendedAsyncClient(mockSqsBackend, extendedClientConfigurationWithCustomKMS));
+ extendedSqsWithDefaultKMS = spy(new AmazonSQSExtendedAsyncClient(mockSqsBackend, extendedClientConfigurationWithDefaultKMS));
+ extendedSqsWithGenericReservedAttributeName = spy(new AmazonSQSExtendedAsyncClient(mockSqsBackend, extendedClientConfigurationWithGenericReservedAttributeName));
+ }
+
+ @Test
+ public void testWhenSendMessageWithLargePayloadSupportDisabledThenS3IsNotUsedAndSqsBackendIsResponsibleToFailItWithDeprecatedMethod() {
+ int messageLength = MORE_THAN_SQS_SIZE_LIMIT;
+ String messageBody = generateStringWithLength(messageLength);
+ ExtendedAsyncClientConfiguration extendedClientConfiguration = new ExtendedAsyncClientConfiguration()
+ .withPayloadSupportDisabled();
+ SqsAsyncClient sqsExtended = spy(new AmazonSQSExtendedAsyncClient(mockSqsBackend, extendedClientConfiguration));
+
+ SendMessageRequest messageRequest = SendMessageRequest.builder()
+ .queueUrl(SQS_QUEUE_URL)
+ .messageBody(messageBody)
+ .overrideConfiguration(
+ AwsRequestOverrideConfiguration.builder()
+ .addApiName(ApiName.builder().name(USER_AGENT_NAME).version(USER_AGENT_VERSION).build())
+ .build())
+ .build();
+ sqsExtended.sendMessage(messageRequest);
+
+ ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(SendMessageRequest.class);
+
+ verify(mockS3, never()).putObject(isA(PutObjectRequest.class), isA(AsyncRequestBody.class));
+ verify(mockSqsBackend).sendMessage(argumentCaptor.capture());
+ assertEquals(messageRequest.queueUrl(), argumentCaptor.getValue().queueUrl());
+ assertEquals(messageRequest.messageBody(), argumentCaptor.getValue().messageBody());
+ assertEquals(messageRequest.overrideConfiguration().get().apiNames().get(0).name(), argumentCaptor.getValue().overrideConfiguration().get().apiNames().get(0).name());
+ assertEquals(messageRequest.overrideConfiguration().get().apiNames().get(0).version(), argumentCaptor.getValue().overrideConfiguration().get().apiNames().get(0).version());
+ }
+
+ @Test
+ public void testWhenSendMessageWithAlwaysThroughS3AndMessageIsSmallThenItIsStillStoredInS3WithDeprecatedMethod() {
+ int messageLength = LESS_THAN_SQS_SIZE_LIMIT;
+ String messageBody = generateStringWithLength(messageLength);
+ ExtendedAsyncClientConfiguration extendedClientConfiguration = new ExtendedAsyncClientConfiguration()
+ .withPayloadSupportEnabled(mockS3, S3_BUCKET_NAME).withAlwaysThroughS3(true);
+ SqsAsyncClient sqsExtended = spy(new AmazonSQSExtendedAsyncClient(mockSqsBackend, extendedClientConfiguration));
+
+ SendMessageRequest messageRequest = SendMessageRequest.builder().queueUrl(SQS_QUEUE_URL).messageBody(messageBody).build();
+ sqsExtended.sendMessage(messageRequest).join();
+
+ verify(mockS3, times(1)).putObject(isA(PutObjectRequest.class), isA(AsyncRequestBody.class));
+ }
+
+ @Test
+ public void testWhenSendMessageWithSetMessageSizeThresholdThenThresholdIsHonoredWithDeprecatedMethod() {
+ int messageLength = ARBITRARY_SMALLER_THRESHOLD * 2;
+ String messageBody = generateStringWithLength(messageLength);
+ ExtendedAsyncClientConfiguration extendedClientConfiguration = new ExtendedAsyncClientConfiguration()
+ .withPayloadSupportEnabled(mockS3, S3_BUCKET_NAME).withPayloadSizeThreshold(ARBITRARY_SMALLER_THRESHOLD);
+
+ SqsAsyncClient sqsExtended = spy(new AmazonSQSExtendedAsyncClient(mockSqsBackend, extendedClientConfiguration));
+
+ SendMessageRequest messageRequest = SendMessageRequest.builder().queueUrl(SQS_QUEUE_URL).messageBody(messageBody).build();
+ sqsExtended.sendMessage(messageRequest).join();
+ verify(mockS3, times(1)).putObject(isA(PutObjectRequest.class), isA(AsyncRequestBody.class));
+ }
+
+ @Test
+ public void testReceiveMessageMultipleTimesDoesNotAdditionallyAlterReceiveMessageRequestWithDeprecatedMethod() {
+ ExtendedAsyncClientConfiguration extendedClientConfiguration = new ExtendedAsyncClientConfiguration()
+ .withPayloadSupportEnabled(mockS3, S3_BUCKET_NAME);
+ SqsAsyncClient sqsExtended = spy(new AmazonSQSExtendedAsyncClient(mockSqsBackend, extendedClientConfiguration));
+ when(mockSqsBackend.receiveMessage(isA(ReceiveMessageRequest.class))).thenReturn(
+ CompletableFuture.completedFuture(ReceiveMessageResponse.builder().build()));
+
+ ReceiveMessageRequest messageRequest = ReceiveMessageRequest.builder().build();
+ ReceiveMessageRequest expectedRequest = ReceiveMessageRequest.builder().build();
+
+ sqsExtended.receiveMessage(messageRequest).join();
+ assertEquals(expectedRequest, messageRequest);
+
+ sqsExtended.receiveMessage(messageRequest).join();
+ assertEquals(expectedRequest, messageRequest);
+ }
+
+ @Test
+ public void testWhenSendLargeMessageThenPayloadIsStoredInS3() {
+ String messageBody = generateStringWithLength(MORE_THAN_SQS_SIZE_LIMIT);
+
+ SendMessageRequest messageRequest = SendMessageRequest.builder().queueUrl(SQS_QUEUE_URL).messageBody(messageBody).build();
+ extendedSqsWithDefaultConfig.sendMessage(messageRequest).join();
+
+ verify(mockS3, times(1)).putObject(isA(PutObjectRequest.class), isA(AsyncRequestBody.class));
+ }
+
+ @Test
+ public void testWhenSendLargeMessage_WithoutKMS_ThenPayloadIsStoredInS3AndKMSKeyIdIsNotUsed() {
+ String messageBody = generateStringWithLength(MORE_THAN_SQS_SIZE_LIMIT);
+
+ SendMessageRequest messageRequest = SendMessageRequest.builder().queueUrl(SQS_QUEUE_URL).messageBody(messageBody).build();
+ extendedSqsWithDefaultConfig.sendMessage(messageRequest).join();
+
+ ArgumentCaptor putObjectRequestArgumentCaptor = ArgumentCaptor.forClass(PutObjectRequest.class);
+ ArgumentCaptor requestBodyArgumentCaptor = ArgumentCaptor.forClass(AsyncRequestBody.class);
+ verify(mockS3, times(1)).putObject(putObjectRequestArgumentCaptor.capture(), requestBodyArgumentCaptor.capture());
+
+ assertNull(putObjectRequestArgumentCaptor.getValue().serverSideEncryption());
+ assertEquals(putObjectRequestArgumentCaptor.getValue().bucket(), S3_BUCKET_NAME);
+ }
+
+ @Test
+ public void testWhenSendLargeMessage_WithCustomKMS_ThenPayloadIsStoredInS3AndCorrectKMSKeyIdIsNotUsed() {
+ String messageBody = generateStringWithLength(MORE_THAN_SQS_SIZE_LIMIT);
+
+ SendMessageRequest messageRequest = SendMessageRequest.builder().queueUrl(SQS_QUEUE_URL).messageBody(messageBody).build();
+ extendedSqsWithCustomKMS.sendMessage(messageRequest).join();
+
+ ArgumentCaptor putObjectRequestArgumentCaptor = ArgumentCaptor.forClass(PutObjectRequest.class);
+ ArgumentCaptor requestBodyArgumentCaptor = ArgumentCaptor.forClass(AsyncRequestBody.class);
+ verify(mockS3, times(1)).putObject(putObjectRequestArgumentCaptor.capture(), requestBodyArgumentCaptor.capture());
+
+ assertEquals(putObjectRequestArgumentCaptor.getValue().ssekmsKeyId(), S3_SERVER_SIDE_ENCRYPTION_KMS_KEY_ID);
+ assertEquals(putObjectRequestArgumentCaptor.getValue().bucket(), S3_BUCKET_NAME);
+ }
+
+ @Test
+ public void testWhenSendLargeMessage_WithDefaultKMS_ThenPayloadIsStoredInS3AndCorrectKMSKeyIdIsNotUsed() {
+ String messageBody = generateStringWithLength(MORE_THAN_SQS_SIZE_LIMIT);
+
+ SendMessageRequest messageRequest = SendMessageRequest.builder().queueUrl(SQS_QUEUE_URL).messageBody(messageBody).build();
+ extendedSqsWithDefaultKMS.sendMessage(messageRequest).join();
+
+ ArgumentCaptor putObjectRequestArgumentCaptor = ArgumentCaptor.forClass(PutObjectRequest.class);
+ ArgumentCaptor requestBodyArgumentCaptor = ArgumentCaptor.forClass(AsyncRequestBody.class);
+ verify(mockS3, times(1)).putObject(putObjectRequestArgumentCaptor.capture(), requestBodyArgumentCaptor.capture());
+
+ assertTrue(putObjectRequestArgumentCaptor.getValue().serverSideEncryption() != null &&
+ putObjectRequestArgumentCaptor.getValue().ssekmsKeyId() == null);
+ assertEquals(putObjectRequestArgumentCaptor.getValue().bucket(), S3_BUCKET_NAME);
+ }
+
+ @Test
+ public void testSendLargeMessageWithDefaultConfigThenLegacyReservedAttributeNameIsUsed(){
+ String messageBody = generateStringWithLength(MORE_THAN_SQS_SIZE_LIMIT);
+ SendMessageRequest messageRequest = SendMessageRequest.builder().queueUrl(SQS_QUEUE_URL).messageBody(messageBody).build();
+ extendedSqsWithDefaultConfig.sendMessage(messageRequest).join();
+
+ ArgumentCaptor sendMessageRequestCaptor = ArgumentCaptor.forClass(SendMessageRequest.class);
+ verify(mockSqsBackend).sendMessage(sendMessageRequestCaptor.capture());
+
+ Map attributes = sendMessageRequestCaptor.getValue().messageAttributes();
+ assertTrue(attributes.containsKey(AmazonSQSExtendedClientUtil.LEGACY_RESERVED_ATTRIBUTE_NAME));
+ assertFalse(attributes.containsKey(SQSExtendedClientConstants.RESERVED_ATTRIBUTE_NAME));
+
+ }
+
+ @Test
+ public void testSendLargeMessageWithGenericReservedAttributeNameConfigThenGenericReservedAttributeNameIsUsed(){
+ String messageBody = generateStringWithLength(MORE_THAN_SQS_SIZE_LIMIT);
+ SendMessageRequest messageRequest = SendMessageRequest.builder().queueUrl(SQS_QUEUE_URL).messageBody(messageBody).build();
+ extendedSqsWithGenericReservedAttributeName.sendMessage(messageRequest).join();
+
+ ArgumentCaptor sendMessageRequestCaptor = ArgumentCaptor.forClass(SendMessageRequest.class);
+ verify(mockSqsBackend).sendMessage(sendMessageRequestCaptor.capture());
+
+ Map attributes = sendMessageRequestCaptor.getValue().messageAttributes();
+ assertTrue(attributes.containsKey(SQSExtendedClientConstants.RESERVED_ATTRIBUTE_NAME));
+ assertFalse(attributes.containsKey(AmazonSQSExtendedClientUtil.LEGACY_RESERVED_ATTRIBUTE_NAME));
+ }
+
+ @Test
+ public void testWhenSendSmallMessageThenS3IsNotUsed() {
+ String messageBody = generateStringWithLength(SQS_SIZE_LIMIT);
+
+ SendMessageRequest messageRequest = SendMessageRequest.builder().queueUrl(SQS_QUEUE_URL).messageBody(messageBody).build();
+ extendedSqsWithDefaultConfig.sendMessage(messageRequest).join();
+
+ verify(mockS3, never()).putObject(isA(PutObjectRequest.class), isA(AsyncRequestBody.class));
+ }
+
+ @Test
+ public void testWhenSendMessageWithLargePayloadSupportDisabledThenS3IsNotUsedAndSqsBackendIsResponsibleToFailIt() {
+ String messageBody = generateStringWithLength(MORE_THAN_SQS_SIZE_LIMIT);
+ ExtendedAsyncClientConfiguration extendedClientConfiguration = new ExtendedAsyncClientConfiguration()
+ .withPayloadSupportDisabled();
+ SqsAsyncClient sqsExtended = spy(new AmazonSQSExtendedAsyncClient(mockSqsBackend, extendedClientConfiguration));
+
+ SendMessageRequest messageRequest = SendMessageRequest.builder()
+ .queueUrl(SQS_QUEUE_URL)
+ .messageBody(messageBody)
+ .overrideConfiguration(
+ AwsRequestOverrideConfiguration.builder()
+ .addApiName(ApiName.builder().name(USER_AGENT_NAME).version(USER_AGENT_VERSION).build())
+ .build())
+ .build();
+ sqsExtended.sendMessage(messageRequest).join();
+
+ ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(SendMessageRequest.class);
+
+ verify(mockS3, never()).putObject(isA(PutObjectRequest.class), isA(AsyncRequestBody.class));
+ verify(mockSqsBackend).sendMessage(argumentCaptor.capture());
+ assertEquals(messageRequest.queueUrl(), argumentCaptor.getValue().queueUrl());
+ assertEquals(messageRequest.messageBody(), argumentCaptor.getValue().messageBody());
+ assertEquals(messageRequest.overrideConfiguration().get().apiNames().get(0).name(), argumentCaptor.getValue().overrideConfiguration().get().apiNames().get(0).name());
+ assertEquals(messageRequest.overrideConfiguration().get().apiNames().get(0).version(), argumentCaptor.getValue().overrideConfiguration().get().apiNames().get(0).version());
+ }
+
+ @Test
+ public void testWhenSendMessageWithAlwaysThroughS3AndMessageIsSmallThenItIsStillStoredInS3() {
+ String messageBody = generateStringWithLength(LESS_THAN_SQS_SIZE_LIMIT);
+ ExtendedAsyncClientConfiguration extendedClientConfiguration = new ExtendedAsyncClientConfiguration()
+ .withPayloadSupportEnabled(mockS3, S3_BUCKET_NAME).withAlwaysThroughS3(true);
+ SqsAsyncClient sqsExtended = spy(new AmazonSQSExtendedAsyncClient(mockSqsBackend, extendedClientConfiguration));
+
+ SendMessageRequest messageRequest = SendMessageRequest.builder().queueUrl(SQS_QUEUE_URL).messageBody(messageBody).build();
+ sqsExtended.sendMessage(messageRequest).join();
+
+ verify(mockS3, times(1)).putObject(isA(PutObjectRequest.class), isA(AsyncRequestBody.class));
+ }
+
+ @Test
+ public void testWhenSendMessageWithSetMessageSizeThresholdThenThresholdIsHonored() {
+ int messageLength = ARBITRARY_SMALLER_THRESHOLD * 2;
+ String messageBody = generateStringWithLength(messageLength);
+ ExtendedAsyncClientConfiguration extendedClientConfiguration = new ExtendedAsyncClientConfiguration()
+ .withPayloadSupportEnabled(mockS3, S3_BUCKET_NAME).withPayloadSizeThreshold(ARBITRARY_SMALLER_THRESHOLD);
+
+ SqsAsyncClient sqsExtended = spy(new AmazonSQSExtendedAsyncClient(mockSqsBackend, extendedClientConfiguration));
+
+ SendMessageRequest messageRequest = SendMessageRequest.builder().queueUrl(SQS_QUEUE_URL).messageBody(messageBody).build();
+ sqsExtended.sendMessage(messageRequest).join();
+ verify(mockS3, times(1)).putObject(isA(PutObjectRequest.class), isA(AsyncRequestBody.class));
+ }
+
+ @Test
+ public void testReceiveMessageMultipleTimesDoesNotAdditionallyAlterReceiveMessageRequest() {
+ when(mockSqsBackend.receiveMessage(isA(ReceiveMessageRequest.class))).thenReturn(
+ CompletableFuture.completedFuture(ReceiveMessageResponse.builder().build()));
+
+ ReceiveMessageRequest messageRequest = ReceiveMessageRequest.builder().build();
+
+ ReceiveMessageRequest expectedRequest = ReceiveMessageRequest.builder().build();
+
+ extendedSqsWithDefaultConfig.receiveMessage(messageRequest).join();
+ assertEquals(expectedRequest, messageRequest);
+
+ extendedSqsWithDefaultConfig.receiveMessage(messageRequest).join();
+ assertEquals(expectedRequest, messageRequest);
+ }
+
+ @Test
+ public void testReceiveMessage_when_MessageIsLarge_legacyReservedAttributeUsed() throws Exception {
+ testReceiveMessage_when_MessageIsLarge(AmazonSQSExtendedClientUtil.LEGACY_RESERVED_ATTRIBUTE_NAME);
+ }
+
+ @Test
+ public void testReceiveMessage_when_MessageIsLarge_ReservedAttributeUsed() throws Exception {
+ testReceiveMessage_when_MessageIsLarge(SQSExtendedClientConstants.RESERVED_ATTRIBUTE_NAME);
+ }
+
+ @Test
+ public void testReceiveMessage_when_MessageIsSmall() throws Exception {
+ String expectedMessageAttributeName = "AnyMessageAttribute";
+ String expectedMessage = "SmallMessage";
+ Message message = Message.builder()
+ .messageAttributes(ImmutableMap.of(expectedMessageAttributeName, MessageAttributeValue.builder().build()))
+ .body(expectedMessage)
+ .build();
+ when(mockSqsBackend.receiveMessage(isA(ReceiveMessageRequest.class))).thenReturn(
+ CompletableFuture.completedFuture(ReceiveMessageResponse.builder().messages(message).build()));
+
+ ReceiveMessageRequest messageRequest = ReceiveMessageRequest.builder().build();
+ ReceiveMessageResponse actualReceiveMessageResponse = extendedSqsWithDefaultConfig.receiveMessage(messageRequest).join();
+ Message actualMessage = actualReceiveMessageResponse.messages().get(0);
+
+ assertEquals(expectedMessage, actualMessage.body());
+ assertTrue(actualMessage.messageAttributes().containsKey(expectedMessageAttributeName));
+ assertFalse(actualMessage.messageAttributes().keySet().containsAll(AmazonSQSExtendedClientUtil.RESERVED_ATTRIBUTE_NAMES));
+ verifyNoInteractions(mockS3);
+ }
+
+ @Test
+ public void testWhenMessageBatchIsSentThenOnlyMessagesLargerThanThresholdAreStoredInS3() {
+ // This creates 10 messages, out of which only two are below the threshold (100K and 200K),
+ // and the other 8 are above the threshold
+
+ int[] messageLengthForCounter = new int[] {
+ 100_000,
+ 300_000,
+ 400_000,
+ 500_000,
+ 600_000,
+ 700_000,
+ 800_000,
+ 900_000,
+ 200_000,
+ 1000_000
+ };
+
+ List batchEntries = new ArrayList();
+ for (int i = 0; i < 10; i++) {
+ int messageLength = messageLengthForCounter[i];
+ String messageBody = generateStringWithLength(messageLength);
+ SendMessageBatchRequestEntry entry = SendMessageBatchRequestEntry.builder()
+ .id("entry_" + i)
+ .messageBody(messageBody)
+ .build();
+ batchEntries.add(entry);
+ }
+
+ SendMessageBatchRequest
+ batchRequest = SendMessageBatchRequest.builder().queueUrl(SQS_QUEUE_URL).entries(batchEntries).build();
+ extendedSqsWithDefaultConfig.sendMessageBatch(batchRequest).join();
+
+ // There should be 8 puts for the 8 messages above the threshold
+ verify(mockS3, times(8)).putObject(isA(PutObjectRequest.class), isA(AsyncRequestBody.class));
+ }
+
+ @Test
+ public void testWhenMessageBatchIsLargeS3PointerIsCorrectlySentToSQSAndNotOriginalMessage() {
+ String messageBody = generateStringWithLength(LESS_THAN_SQS_SIZE_LIMIT);
+ ExtendedAsyncClientConfiguration extendedClientConfiguration = new ExtendedAsyncClientConfiguration()
+ .withPayloadSupportEnabled(mockS3, S3_BUCKET_NAME).withAlwaysThroughS3(true);
+
+ SqsAsyncClient sqsExtended = spy(new AmazonSQSExtendedAsyncClient(mockSqsBackend, extendedClientConfiguration));
+
+ List batchEntries = new ArrayList();
+ for (int i = 0; i < 10; i++) {
+ SendMessageBatchRequestEntry entry = SendMessageBatchRequestEntry.builder()
+ .id("entry_" + i)
+ .messageBody(messageBody)
+ .build();
+ batchEntries.add(entry);
+ }
+ SendMessageBatchRequest batchRequest = SendMessageBatchRequest.builder().queueUrl(SQS_QUEUE_URL).entries(batchEntries).build();
+
+ sqsExtended.sendMessageBatch(batchRequest).join();
+
+ ArgumentCaptor sendMessageRequestCaptor = ArgumentCaptor.forClass(SendMessageBatchRequest.class);
+ verify(mockSqsBackend).sendMessageBatch(sendMessageRequestCaptor.capture());
+
+ for (SendMessageBatchRequestEntry entry : sendMessageRequestCaptor.getValue().entries()) {
+ assertNotEquals(messageBody, entry.messageBody());
+ }
+ }
+
+ @Test
+ public void testWhenSmallMessageIsSentThenNoAttributeIsAdded() {
+ String messageBody = generateStringWithLength(LESS_THAN_SQS_SIZE_LIMIT);
+
+ SendMessageRequest messageRequest = SendMessageRequest.builder().queueUrl(SQS_QUEUE_URL).messageBody(messageBody).build();
+ extendedSqsWithDefaultConfig.sendMessage(messageRequest).join();
+
+ ArgumentCaptor sendMessageRequestCaptor = ArgumentCaptor.forClass(SendMessageRequest.class);
+ verify(mockSqsBackend).sendMessage(sendMessageRequestCaptor.capture());
+
+ Map attributes = sendMessageRequestCaptor.getValue().messageAttributes();
+ assertTrue(attributes.isEmpty());
+ }
+
+ @Test
+ public void testWhenLargeMessageIsSentThenAttributeWithPayloadSizeIsAdded() {
+ int messageLength = MORE_THAN_SQS_SIZE_LIMIT;
+ String messageBody = generateStringWithLength(messageLength);
+
+ SendMessageRequest messageRequest = SendMessageRequest.builder().queueUrl(SQS_QUEUE_URL).messageBody(messageBody).build();
+ extendedSqsWithDefaultConfig.sendMessage(messageRequest).join();
+
+ ArgumentCaptor sendMessageRequestCaptor = ArgumentCaptor.forClass(SendMessageRequest.class);
+ verify(mockSqsBackend).sendMessage(sendMessageRequestCaptor.capture());
+
+ Map attributes = sendMessageRequestCaptor.getValue().messageAttributes();
+ assertEquals("Number", attributes.get(AmazonSQSExtendedClientUtil.LEGACY_RESERVED_ATTRIBUTE_NAME).dataType());
+ assertEquals(messageLength, (int) Integer.parseInt(attributes.get(AmazonSQSExtendedClientUtil.LEGACY_RESERVED_ATTRIBUTE_NAME).stringValue()));
+ }
+
+ @Test
+ public void testDefaultExtendedClientDeletesSmallMessage() {
+ // given
+ String receiptHandle = UUID.randomUUID().toString();
+ DeleteMessageRequest
+ deleteRequest = DeleteMessageRequest.builder().queueUrl(SQS_QUEUE_URL).receiptHandle(receiptHandle).build();
+
+ // when
+ extendedSqsWithDefaultConfig.deleteMessage(deleteRequest).join();
+
+ // then
+ ArgumentCaptor deleteRequestCaptor = ArgumentCaptor.forClass(DeleteMessageRequest.class);
+ verify(mockSqsBackend).deleteMessage(deleteRequestCaptor.capture());
+ assertEquals(receiptHandle, deleteRequestCaptor.getValue().receiptHandle());
+ verifyNoInteractions(mockS3);
+ }
+
+ @Test
+ public void testDefaultExtendedClientDeletesObjectS3UponMessageDelete() {
+ // given
+ String randomS3Key = UUID.randomUUID().toString();
+ String originalReceiptHandle = UUID.randomUUID().toString();
+ String largeMessageReceiptHandle = getLargeReceiptHandle(randomS3Key, originalReceiptHandle);
+ DeleteMessageRequest deleteRequest = DeleteMessageRequest.builder().queueUrl(SQS_QUEUE_URL).receiptHandle(largeMessageReceiptHandle).build();
+
+ // when
+ extendedSqsWithDefaultConfig.deleteMessage(deleteRequest).join();
+
+ // then
+ ArgumentCaptor deleteRequestCaptor = ArgumentCaptor.forClass(DeleteMessageRequest.class);
+ verify(mockSqsBackend).deleteMessage(deleteRequestCaptor.capture());
+ assertEquals(originalReceiptHandle, deleteRequestCaptor.getValue().receiptHandle());
+ DeleteObjectRequest deleteObjectRequest = DeleteObjectRequest.builder().bucket(S3_BUCKET_NAME).key(randomS3Key).build();
+ verify(mockS3).deleteObject(eq(deleteObjectRequest));
+ }
+
+ @Test
+ public void testExtendedClientConfiguredDoesNotDeleteObjectFromS3UponDelete() {
+ // given
+ String randomS3Key = UUID.randomUUID().toString();
+ String originalReceiptHandle = UUID.randomUUID().toString();
+ String largeMessageReceiptHandle = getLargeReceiptHandle(randomS3Key, originalReceiptHandle);
+ DeleteMessageRequest deleteRequest = DeleteMessageRequest.builder().queueUrl(SQS_QUEUE_URL).receiptHandle(largeMessageReceiptHandle).build();
+
+ ExtendedAsyncClientConfiguration extendedClientConfiguration = new ExtendedAsyncClientConfiguration()
+ .withPayloadSupportEnabled(mockS3, S3_BUCKET_NAME, false);
+
+ SqsAsyncClient extendedSqs = spy(new AmazonSQSExtendedAsyncClient(mockSqsBackend, extendedClientConfiguration));
+
+ // when
+ extendedSqs.deleteMessage(deleteRequest).join();
+
+ // then
+ ArgumentCaptor deleteRequestCaptor = ArgumentCaptor.forClass(DeleteMessageRequest.class);
+ verify(mockSqsBackend).deleteMessage(deleteRequestCaptor.capture());
+ assertEquals(originalReceiptHandle, deleteRequestCaptor.getValue().receiptHandle());
+ verifyNoInteractions(mockS3);
+ }
+
+ @Test
+ public void testExtendedClientConfiguredDoesNotDeletesObjectsFromS3UponDeleteBatch() {
+ // given
+ int batchSize = 10;
+ List originalReceiptHandles = IntStream.range(0, batchSize)
+ .mapToObj(i -> UUID.randomUUID().toString())
+ .collect(Collectors.toList());
+ DeleteMessageBatchRequest deleteBatchRequest = generateLargeDeleteBatchRequest(originalReceiptHandles);
+ ExtendedAsyncClientConfiguration extendedClientConfiguration = new ExtendedAsyncClientConfiguration()
+ .withPayloadSupportEnabled(mockS3, S3_BUCKET_NAME, false);
+ SqsAsyncClient extendedSqs = spy(new AmazonSQSExtendedAsyncClient(mockSqsBackend, extendedClientConfiguration));
+
+ // when
+ extendedSqs.deleteMessageBatch(deleteBatchRequest).join();
+
+ // then
+ ArgumentCaptor deleteBatchRequestCaptor = ArgumentCaptor.forClass(DeleteMessageBatchRequest.class);
+ verify(mockSqsBackend, times(1)).deleteMessageBatch(deleteBatchRequestCaptor.capture());
+ DeleteMessageBatchRequest request = deleteBatchRequestCaptor.getValue();
+ assertEquals(originalReceiptHandles.size(), request.entries().size());
+ IntStream.range(0, originalReceiptHandles.size()).forEach(i -> assertEquals(
+ originalReceiptHandles.get(i),
+ request.entries().get(i).receiptHandle()));
+ verifyNoInteractions(mockS3);
+ }
+
+ @Test
+ public void testDefaultExtendedClientDeletesObjectsFromS3UponDeleteBatch() {
+ // given
+ int batchSize = 10;
+ List originalReceiptHandles = IntStream.range(0, batchSize)
+ .mapToObj(i -> UUID.randomUUID().toString())
+ .collect(Collectors.toList());
+ DeleteMessageBatchRequest deleteBatchRequest = generateLargeDeleteBatchRequest(originalReceiptHandles);
+
+ // when
+ extendedSqsWithDefaultConfig.deleteMessageBatch(deleteBatchRequest).join();
+
+ // then
+ ArgumentCaptor deleteBatchRequestCaptor = ArgumentCaptor.forClass(DeleteMessageBatchRequest.class);
+ verify(mockSqsBackend, times(1)).deleteMessageBatch(deleteBatchRequestCaptor.capture());
+ DeleteMessageBatchRequest request = deleteBatchRequestCaptor.getValue();
+ assertEquals(originalReceiptHandles.size(), request.entries().size());
+ IntStream.range(0, originalReceiptHandles.size()).forEach(i -> assertEquals(
+ originalReceiptHandles.get(i),
+ request.entries().get(i).receiptHandle()));
+ verify(mockS3, times(batchSize)).deleteObject(any(DeleteObjectRequest.class));
+ }
+
+ @Test
+ public void testWhenSendMessageWIthCannedAccessControlListDefined() {
+ ObjectCannedACL expected = ObjectCannedACL.BUCKET_OWNER_FULL_CONTROL;
+ String messageBody = generateStringWithLength(MORE_THAN_SQS_SIZE_LIMIT);
+ ExtendedAsyncClientConfiguration extendedClientConfiguration = new ExtendedAsyncClientConfiguration()
+ .withPayloadSupportEnabled(mockS3, S3_BUCKET_NAME).withObjectCannedACL(expected);
+ SqsAsyncClient sqsExtended = spy(new AmazonSQSExtendedAsyncClient(mockSqsBackend, extendedClientConfiguration));
+
+ SendMessageRequest messageRequest = SendMessageRequest.builder().queueUrl(SQS_QUEUE_URL).messageBody(messageBody).build();
+ sqsExtended.sendMessage(messageRequest).join();
+
+ ArgumentCaptor captor = ArgumentCaptor.forClass(PutObjectRequest.class);
+
+ verify(mockS3).putObject(captor.capture(), any(AsyncRequestBody.class));
+
+ assertEquals(expected, captor.getValue().acl());
+ }
+
+ private void testReceiveMessage_when_MessageIsLarge(String reservedAttributeName) throws Exception {
+ String pointer = new PayloadS3Pointer(S3_BUCKET_NAME, "S3Key").toJson();
+ Message message = Message.builder()
+ .messageAttributes(ImmutableMap.of(reservedAttributeName, MessageAttributeValue.builder().build()))
+ .body(pointer)
+ .build();
+ String expectedMessage = "LargeMessage";
+ GetObjectRequest getObjectRequest = GetObjectRequest.builder()
+ .bucket(S3_BUCKET_NAME)
+ .key("S3Key")
+ .build();
+
+ ResponseBytes s3Object = ResponseBytes.fromByteArray(
+ GetObjectResponse.builder().build(),
+ expectedMessage.getBytes(StandardCharsets.UTF_8));
+ when(mockSqsBackend.receiveMessage(isA(ReceiveMessageRequest.class))).thenReturn(
+ CompletableFuture.completedFuture(ReceiveMessageResponse.builder().messages(message).build()));
+ when(mockS3.getObject(isA(GetObjectRequest.class), isA(AsyncResponseTransformer.class))).thenReturn(
+ CompletableFuture.completedFuture(s3Object));
+
+ ReceiveMessageRequest messageRequest = ReceiveMessageRequest.builder().build();
+ ReceiveMessageResponse actualReceiveMessageResponse = extendedSqsWithDefaultConfig.receiveMessage(messageRequest).join();
+ Message actualMessage = actualReceiveMessageResponse.messages().get(0);
+
+ assertEquals(expectedMessage, actualMessage.body());
+ assertFalse(actualMessage.messageAttributes().keySet().containsAll(AmazonSQSExtendedClientUtil.RESERVED_ATTRIBUTE_NAMES));
+ verify(mockS3, times(1)).getObject(isA(GetObjectRequest.class), isA(AsyncResponseTransformer.class));
+ }
+
+ private DeleteMessageBatchRequest generateLargeDeleteBatchRequest(List originalReceiptHandles) {
+ List deleteEntries = IntStream.range(0, originalReceiptHandles.size())
+ .mapToObj(i -> DeleteMessageBatchRequestEntry.builder()
+ .id(Integer.toString(i))
+ .receiptHandle(getSampleLargeReceiptHandle(originalReceiptHandles.get(i)))
+ .build())
+ .collect(Collectors.toList());
+
+ return DeleteMessageBatchRequest.builder().queueUrl(SQS_QUEUE_URL).entries(deleteEntries).build();
+ }
+
+ private String getLargeReceiptHandle(String s3Key, String originalReceiptHandle) {
+ return SQSExtendedClientConstants.S3_BUCKET_NAME_MARKER + S3_BUCKET_NAME
+ + SQSExtendedClientConstants.S3_BUCKET_NAME_MARKER + SQSExtendedClientConstants.S3_KEY_MARKER
+ + s3Key + SQSExtendedClientConstants.S3_KEY_MARKER + originalReceiptHandle;
+ }
+
+ private String getSampleLargeReceiptHandle(String originalReceiptHandle) {
+ return getLargeReceiptHandle(UUID.randomUUID().toString(), originalReceiptHandle);
+ }
+
+ private String generateStringWithLength(int messageLength) {
+ char[] charArray = new char[messageLength];
+ Arrays.fill(charArray, 'x');
+ return new String(charArray);
+ }
+}
diff --git a/src/test/java/com/amazon/sqs/javamessaging/AmazonSQSExtendedClientTest.java b/src/test/java/com/amazon/sqs/javamessaging/AmazonSQSExtendedClientTest.java
index eda2179..fd58b0b 100644
--- a/src/test/java/com/amazon/sqs/javamessaging/AmazonSQSExtendedClientTest.java
+++ b/src/test/java/com/amazon/sqs/javamessaging/AmazonSQSExtendedClientTest.java
@@ -51,7 +51,6 @@
import software.amazon.payloadoffloading.ServerSideEncryptionStrategy;
import java.util.ArrayList;
-import java.util.Arrays;
import java.util.List;
import java.util.Map;
import java.util.UUID;
@@ -295,7 +294,7 @@ public void testSendLargeMessageWithDefaultConfigThenLegacyReservedAttributeName
verify(mockSqsBackend).sendMessage(sendMessageRequestCaptor.capture());
Map attributes = sendMessageRequestCaptor.getValue().messageAttributes();
- assertTrue(attributes.containsKey(AmazonSQSExtendedClient.LEGACY_RESERVED_ATTRIBUTE_NAME));
+ assertTrue(attributes.containsKey(AmazonSQSExtendedClientUtil.LEGACY_RESERVED_ATTRIBUTE_NAME));
assertFalse(attributes.containsKey(SQSExtendedClientConstants.RESERVED_ATTRIBUTE_NAME));
}
@@ -310,7 +309,7 @@ public void testSendLargeMessageWithGenericReservedAttributeNameConfigThenGeneri
Map attributes = sendMessageRequestCaptor.getValue().messageAttributes();
assertTrue(attributes.containsKey(SQSExtendedClientConstants.RESERVED_ATTRIBUTE_NAME));
- assertFalse(attributes.containsKey(AmazonSQSExtendedClient.LEGACY_RESERVED_ATTRIBUTE_NAME));
+ assertFalse(attributes.containsKey(AmazonSQSExtendedClientUtil.LEGACY_RESERVED_ATTRIBUTE_NAME));
}
@Test
@@ -396,7 +395,7 @@ public void testReceiveMessageMultipleTimesDoesNotAdditionallyAlterReceiveMessag
@Test
public void testReceiveMessage_when_MessageIsLarge_legacyReservedAttributeUsed() {
- testReceiveMessage_when_MessageIsLarge(AmazonSQSExtendedClient.LEGACY_RESERVED_ATTRIBUTE_NAME);
+ testReceiveMessage_when_MessageIsLarge(AmazonSQSExtendedClientUtil.LEGACY_RESERVED_ATTRIBUTE_NAME);
}
@Test
@@ -420,7 +419,7 @@ public void testReceiveMessage_when_MessageIsSmall() {
assertEquals(expectedMessage, actualMessage.body());
assertTrue(actualMessage.messageAttributes().containsKey(expectedMessageAttributeName));
- assertFalse(actualMessage.messageAttributes().keySet().containsAll(AmazonSQSExtendedClient.RESERVED_ATTRIBUTE_NAMES));
+ assertFalse(actualMessage.messageAttributes().keySet().containsAll(AmazonSQSExtendedClientUtil.RESERVED_ATTRIBUTE_NAMES));
verifyNoInteractions(mockS3);
}
@@ -514,8 +513,8 @@ public void testWhenLargeMessageIsSentThenAttributeWithPayloadSizeIsAdded() {
verify(mockSqsBackend).sendMessage(sendMessageRequestCaptor.capture());
Map attributes = sendMessageRequestCaptor.getValue().messageAttributes();
- assertEquals("Number", attributes.get(AmazonSQSExtendedClient.LEGACY_RESERVED_ATTRIBUTE_NAME).dataType());
- assertEquals(messageLength, Integer.parseInt(attributes.get(AmazonSQSExtendedClient.LEGACY_RESERVED_ATTRIBUTE_NAME).stringValue()));
+ assertEquals("Number", attributes.get(AmazonSQSExtendedClientUtil.LEGACY_RESERVED_ATTRIBUTE_NAME).dataType());
+ assertEquals(messageLength, Integer.parseInt(attributes.get(AmazonSQSExtendedClientUtil.LEGACY_RESERVED_ATTRIBUTE_NAME).stringValue()));
}
@Test
@@ -693,7 +692,7 @@ private void testReceiveMessage_when_MessageIsLarge(String reservedAttributeName
Message actualMessage = actualReceiveMessageResponse.messages().get(0);
assertEquals(expectedMessage, actualMessage.body());
- assertFalse(actualMessage.messageAttributes().keySet().containsAll(AmazonSQSExtendedClient.RESERVED_ATTRIBUTE_NAMES));
+ assertFalse(actualMessage.messageAttributes().keySet().containsAll(AmazonSQSExtendedClientUtil.RESERVED_ATTRIBUTE_NAMES));
verify(mockS3, times(1)).getObject(isA(GetObjectRequest.class));
}
diff --git a/src/test/java/com/amazon/sqs/javamessaging/ExtendedAsyncClientConfigurationTest.java b/src/test/java/com/amazon/sqs/javamessaging/ExtendedAsyncClientConfigurationTest.java
new file mode 100644
index 0000000..879f098
--- /dev/null
+++ b/src/test/java/com/amazon/sqs/javamessaging/ExtendedAsyncClientConfigurationTest.java
@@ -0,0 +1,88 @@
+package com.amazon.sqs.javamessaging;
+
+import static org.junit.jupiter.api.Assertions.assertEquals;
+import static org.junit.jupiter.api.Assertions.assertFalse;
+import static org.junit.jupiter.api.Assertions.assertNotNull;
+import static org.junit.jupiter.api.Assertions.assertNotSame;
+import static org.junit.jupiter.api.Assertions.assertTrue;
+import static org.mockito.Mockito.mock;
+
+import org.junit.jupiter.api.Test;
+import software.amazon.awssdk.services.s3.S3AsyncClient;
+import software.amazon.payloadoffloading.ServerSideEncryptionFactory;
+import software.amazon.payloadoffloading.ServerSideEncryptionStrategy;
+
+/**
+ * Tests the ExtendedAsyncClientConfiguration class.
+ */
+public class ExtendedAsyncClientConfigurationTest {
+
+ private static final String s3BucketName = "test-bucket-name";
+ private static final String s3ServerSideEncryptionKMSKeyId = "test-customer-managed-kms-key-id";
+ private static final ServerSideEncryptionStrategy serverSideEncryptionStrategy =
+ ServerSideEncryptionFactory.customerKey(s3ServerSideEncryptionKMSKeyId);
+
+ @Test
+ public void testCopyConstructor() {
+ S3AsyncClient s3 = mock(S3AsyncClient.class);
+
+ boolean alwaysThroughS3 = true;
+ int messageSizeThreshold = 500;
+ boolean doesCleanupS3Payload = false;
+
+ ExtendedAsyncClientConfiguration extendedClientConfig = new ExtendedAsyncClientConfiguration();
+
+ extendedClientConfig.withPayloadSupportEnabled(s3, s3BucketName, doesCleanupS3Payload)
+ .withAlwaysThroughS3(alwaysThroughS3).withPayloadSizeThreshold(messageSizeThreshold)
+ .withServerSideEncryption(serverSideEncryptionStrategy);
+
+ ExtendedAsyncClientConfiguration newExtendedClientConfig = new ExtendedAsyncClientConfiguration(extendedClientConfig);
+
+ assertEquals(s3, newExtendedClientConfig.getS3AsyncClient());
+ assertEquals(s3BucketName, newExtendedClientConfig.getS3BucketName());
+ assertEquals(serverSideEncryptionStrategy, newExtendedClientConfig.getServerSideEncryptionStrategy());
+ assertTrue(newExtendedClientConfig.isPayloadSupportEnabled());
+ assertEquals(doesCleanupS3Payload, newExtendedClientConfig.doesCleanupS3Payload());
+ assertEquals(alwaysThroughS3, newExtendedClientConfig.isAlwaysThroughS3());
+ assertEquals(messageSizeThreshold, newExtendedClientConfig.getPayloadSizeThreshold());
+
+ assertNotSame(newExtendedClientConfig, extendedClientConfig);
+ }
+
+ @Test
+ public void testLargePayloadSupportEnabledWithDefaultDeleteFromS3Config() {
+ S3AsyncClient s3 = mock(S3AsyncClient.class);
+ ExtendedAsyncClientConfiguration extendedClientConfiguration = new ExtendedAsyncClientConfiguration();
+ extendedClientConfiguration.setPayloadSupportEnabled(s3, s3BucketName);
+
+ assertTrue(extendedClientConfiguration.isPayloadSupportEnabled());
+ assertTrue(extendedClientConfiguration.doesCleanupS3Payload());
+ assertNotNull(extendedClientConfiguration.getS3AsyncClient());
+ assertEquals(s3BucketName, extendedClientConfiguration.getS3BucketName());
+
+ }
+
+ @Test
+ public void testLargePayloadSupportEnabledWithDeleteFromS3Enabled() {
+ S3AsyncClient s3 = mock(S3AsyncClient.class);
+ ExtendedAsyncClientConfiguration extendedClientConfiguration = new ExtendedAsyncClientConfiguration();
+ extendedClientConfiguration.setPayloadSupportEnabled(s3, s3BucketName, true);
+
+ assertTrue(extendedClientConfiguration.isPayloadSupportEnabled());
+ assertTrue(extendedClientConfiguration.doesCleanupS3Payload());
+ assertNotNull(extendedClientConfiguration.getS3AsyncClient());
+ assertEquals(s3BucketName, extendedClientConfiguration.getS3BucketName());
+ }
+
+ @Test
+ public void testLargePayloadSupportEnabledWithDeleteFromS3Disabled() {
+ S3AsyncClient s3 = mock(S3AsyncClient.class);
+ ExtendedAsyncClientConfiguration extendedClientConfiguration = new ExtendedAsyncClientConfiguration();
+ extendedClientConfiguration.setPayloadSupportEnabled(s3, s3BucketName, false);
+
+ assertTrue(extendedClientConfiguration.isPayloadSupportEnabled());
+ assertFalse(extendedClientConfiguration.doesCleanupS3Payload());
+ assertNotNull(extendedClientConfiguration.getS3AsyncClient());
+ assertEquals(s3BucketName, extendedClientConfiguration.getS3BucketName());
+ }
+}