Skip to content

Commit 01fa5d2

Browse files
authored
TM: downloadDirectory refactor how SDK sends concurrent download file requests (#3867)
* Refactor async buffering subscriber
1 parent fd0a4c2 commit 01fa5d2

File tree

2 files changed

+47
-93
lines changed

2 files changed

+47
-93
lines changed

services-custom/s3-transfer-manager/src/main/java/software/amazon/awssdk/transfer/s3/internal/AsyncBufferingSubscriber.java

+17-73
Original file line numberDiff line numberDiff line change
@@ -15,18 +15,14 @@
1515

1616
package software.amazon.awssdk.transfer.s3.internal;
1717

18-
import java.util.Optional;
1918
import java.util.concurrent.CompletableFuture;
20-
import java.util.concurrent.atomic.AtomicBoolean;
2119
import java.util.concurrent.atomic.AtomicInteger;
2220
import java.util.function.Function;
2321
import org.reactivestreams.Subscriber;
2422
import org.reactivestreams.Subscription;
2523
import software.amazon.awssdk.annotations.SdkInternalApi;
2624
import software.amazon.awssdk.utils.Logger;
2725
import software.amazon.awssdk.utils.Validate;
28-
import software.amazon.awssdk.utils.async.DemandIgnoringSubscription;
29-
import software.amazon.awssdk.utils.async.StoringSubscriber;
3026

3127
/**
3228
* An implementation of {@link Subscriber} that execute the provided function for every event and limits the number of concurrent
@@ -41,20 +37,16 @@ public class AsyncBufferingSubscriber<T> implements Subscriber<T> {
4137
private final Function<T, CompletableFuture<?>> consumer;
4238
private final int maxConcurrentExecutions;
4339
private final AtomicInteger numRequestsInFlight;
44-
private final AtomicBoolean isDelivering = new AtomicBoolean(false);
45-
private volatile boolean isStreamingDone;
40+
private volatile boolean upstreamDone;
4641
private Subscription subscription;
4742

48-
private final StoringSubscriber<T> storingSubscriber;
49-
5043
public AsyncBufferingSubscriber(Function<T, CompletableFuture<?>> consumer,
5144
CompletableFuture<Void> returnFuture,
5245
int maxConcurrentExecutions) {
5346
this.returnFuture = returnFuture;
5447
this.consumer = consumer;
5548
this.maxConcurrentExecutions = maxConcurrentExecutions;
5649
this.numRequestsInFlight = new AtomicInteger(0);
57-
this.storingSubscriber = new StoringSubscriber<>(Integer.MAX_VALUE);
5850
}
5951

6052
@Override
@@ -65,89 +57,41 @@ public void onSubscribe(Subscription subscription) {
6557
subscription.cancel();
6658
return;
6759
}
68-
storingSubscriber.onSubscribe(new DemandIgnoringSubscription(subscription));
6960
this.subscription = subscription;
7061
subscription.request(maxConcurrentExecutions);
7162
}
7263

7364
@Override
7465
public void onNext(T item) {
75-
storingSubscriber.onNext(item);
76-
flushBufferIfNeeded();
77-
}
78-
79-
private void flushBufferIfNeeded() {
80-
if (isDelivering.compareAndSet(false, true)) {
81-
try {
82-
Optional<StoringSubscriber.Event<T>> next = storingSubscriber.peek();
83-
while (numRequestsInFlight.get() < maxConcurrentExecutions) {
84-
if (!next.isPresent()) {
85-
subscription.request(1);
86-
break;
87-
}
88-
89-
switch (next.get().type()) {
90-
case ON_COMPLETE:
91-
handleCompleteEvent();
92-
break;
93-
case ON_ERROR:
94-
handleError(next.get().runtimeError());
95-
break;
96-
case ON_NEXT:
97-
handleOnNext(next.get().value());
98-
break;
99-
default:
100-
handleError(new IllegalStateException("Unknown stored type: " + next.get().type()));
101-
break;
102-
}
103-
104-
next = storingSubscriber.peek();
105-
}
106-
} finally {
107-
isDelivering.set(false);
108-
}
109-
}
110-
}
111-
112-
private void handleOnNext(T item) {
113-
storingSubscriber.poll();
114-
115-
int numberOfRequestInFlight = numRequestsInFlight.incrementAndGet();
116-
log.debug(() -> "Delivering next item, numRequestInFlight=" + numberOfRequestInFlight);
117-
66+
numRequestsInFlight.incrementAndGet();
11867
consumer.apply(item).whenComplete((r, t) -> {
119-
numRequestsInFlight.decrementAndGet();
120-
if (!isStreamingDone) {
68+
checkForCompletion(numRequestsInFlight.decrementAndGet());
69+
synchronized (this) {
12170
subscription.request(1);
122-
} else {
123-
flushBufferIfNeeded();
12471
}
12572
});
12673
}
12774

128-
private void handleCompleteEvent() {
129-
if (numRequestsInFlight.get() == 0) {
130-
returnFuture.complete(null);
131-
storingSubscriber.poll();
132-
}
133-
}
134-
13575
@Override
13676
public void onError(Throwable t) {
137-
handleError(t);
138-
storingSubscriber.onError(t);
139-
}
140-
141-
private void handleError(Throwable t) {
77+
// Need to complete future exceptionally first to prevent
78+
// accidental successful completion by a concurrent checkForCompletion.
14279
returnFuture.completeExceptionally(t);
143-
storingSubscriber.poll();
80+
upstreamDone = true;
14481
}
14582

14683
@Override
14784
public void onComplete() {
148-
isStreamingDone = true;
149-
storingSubscriber.onComplete();
150-
flushBufferIfNeeded();
85+
upstreamDone = true;
86+
checkForCompletion(numRequestsInFlight.get());
87+
}
88+
89+
private void checkForCompletion(int requestsInFlight) {
90+
if (upstreamDone && requestsInFlight == 0) {
91+
// This could get invoked multiple times, but it doesn't matter
92+
// because future.complete is idempotent.
93+
returnFuture.complete(null);
94+
}
15195
}
15296

15397
/**

services-custom/s3-transfer-manager/src/test/java/software/amazon/awssdk/transfer/s3/internal/TransferManagerLoggingTest.java

+30-20
Original file line numberDiff line numberDiff line change
@@ -17,42 +17,52 @@
1717

1818
import static org.assertj.core.api.Assertions.assertThat;
1919

20+
import java.util.HashSet;
2021
import java.util.List;
22+
import java.util.Set;
2123
import org.apache.logging.log4j.Level;
2224
import org.apache.logging.log4j.core.LogEvent;
2325
import org.junit.jupiter.api.BeforeEach;
2426
import org.junit.jupiter.api.Test;
27+
import software.amazon.awssdk.auth.credentials.AwsBasicCredentials;
28+
import software.amazon.awssdk.regions.Region;
2529
import software.amazon.awssdk.services.s3.S3AsyncClient;
2630
import software.amazon.awssdk.testutils.LogCaptor;
2731
import software.amazon.awssdk.transfer.s3.S3TransferManager;
2832

29-
public class TransferManagerLoggingTest {
33+
class TransferManagerLoggingTest {
3034

3135
@Test
32-
public void transferManager_withCrtClient_shouldNotLogWarnMessages(){
33-
LogCaptor logCaptor = LogCaptor.create(Level.WARN);
34-
S3AsyncClient s3Crt = S3AsyncClient.crtCreate();
35-
S3TransferManager tm = S3TransferManager.builder().s3Client(s3Crt).build();
36+
void transferManager_withCrtClient_shouldNotLogWarnMessages() {
3637

37-
List<LogEvent> events = logCaptor.loggedEvents();
38-
assertThat(events).isEmpty();
39-
logCaptor.clear();
40-
logCaptor.close();
38+
try (S3AsyncClient s3Crt = S3AsyncClient.crtBuilder()
39+
.region(Region.US_WEST_2)
40+
.credentialsProvider(() -> AwsBasicCredentials.create("foo", "bar"))
41+
.build();
42+
LogCaptor logCaptor = LogCaptor.create(Level.WARN);
43+
S3TransferManager tm = S3TransferManager.builder().s3Client(s3Crt).build()) {
44+
List<LogEvent> events = logCaptor.loggedEvents();
45+
assertThat(events).isEmpty();
46+
}
4147
}
4248

4349
@Test
44-
public void transferManager_withJavaClient_shouldLogWarnMessage(){
45-
LogCaptor logCaptor = LogCaptor.create(Level.WARN);
46-
S3AsyncClient s3Java = S3AsyncClient.create();
47-
S3TransferManager tm = S3TransferManager.builder().s3Client(s3Java).build();
50+
void transferManager_withJavaClient_shouldLogWarnMessage() {
4851

49-
List<LogEvent> events = logCaptor.loggedEvents();
50-
assertLogged(events, Level.WARN, "The provided DefaultS3AsyncClient is not an instance of S3CrtAsyncClient, and "
51-
+ "thus multipart upload/download feature is not enabled and resumable file upload is "
52-
+ "not supported. To benefit from maximum throughput, consider using "
53-
+ "S3AsyncClient.crtBuilder().build() instead.");
54-
logCaptor.clear();
55-
logCaptor.close();
52+
53+
try (S3AsyncClient s3Crt = S3AsyncClient.builder()
54+
.region(Region.US_WEST_2)
55+
.credentialsProvider(() -> AwsBasicCredentials.create("foo", "bar"))
56+
.build();
57+
LogCaptor logCaptor = LogCaptor.create(Level.WARN);
58+
S3TransferManager tm = S3TransferManager.builder().s3Client(s3Crt).build()) {
59+
List<LogEvent> events = logCaptor.loggedEvents();
60+
assertLogged(events, Level.WARN, "The provided DefaultS3AsyncClient is not an instance of S3CrtAsyncClient, and "
61+
+ "thus multipart upload/download feature is not enabled and resumable file upload"
62+
+ " is "
63+
+ "not supported. To benefit from maximum throughput, consider using "
64+
+ "S3AsyncClient.crtBuilder().build() instead.");
65+
}
5666
}
5767

5868
private static void assertLogged(List<LogEvent> events, org.apache.logging.log4j.Level level, String message) {

0 commit comments

Comments
 (0)