Skip to content

TM: downloadDirectory refactor how SDK sends concurrent download file requests #3867

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 6 commits into from
Mar 30, 2023
Merged
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -15,18 +15,14 @@

package software.amazon.awssdk.transfer.s3.internal;

import java.util.Optional;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.function.Function;
import org.reactivestreams.Subscriber;
import org.reactivestreams.Subscription;
import software.amazon.awssdk.annotations.SdkInternalApi;
import software.amazon.awssdk.utils.Logger;
import software.amazon.awssdk.utils.Validate;
import software.amazon.awssdk.utils.async.DemandIgnoringSubscription;
import software.amazon.awssdk.utils.async.StoringSubscriber;

/**
* An implementation of {@link Subscriber} that execute the provided function for every event and limits the number of concurrent
Expand All @@ -41,20 +37,16 @@ public class AsyncBufferingSubscriber<T> implements Subscriber<T> {
private final Function<T, CompletableFuture<?>> consumer;
private final int maxConcurrentExecutions;
private final AtomicInteger numRequestsInFlight;
private final AtomicBoolean isDelivering = new AtomicBoolean(false);
private volatile boolean isStreamingDone;
private volatile boolean upstreamDone;
private Subscription subscription;

private final StoringSubscriber<T> storingSubscriber;

public AsyncBufferingSubscriber(Function<T, CompletableFuture<?>> consumer,
CompletableFuture<Void> returnFuture,
int maxConcurrentExecutions) {
this.returnFuture = returnFuture;
this.consumer = consumer;
this.maxConcurrentExecutions = maxConcurrentExecutions;
this.numRequestsInFlight = new AtomicInteger(0);
this.storingSubscriber = new StoringSubscriber<>(Integer.MAX_VALUE);
}

@Override
Expand All @@ -65,89 +57,40 @@ public void onSubscribe(Subscription subscription) {
subscription.cancel();
return;
}
storingSubscriber.onSubscribe(new DemandIgnoringSubscription(subscription));
this.subscription = subscription;
subscription.request(maxConcurrentExecutions);
}

@Override
public void onNext(T item) {
storingSubscriber.onNext(item);
flushBufferIfNeeded();
}

private void flushBufferIfNeeded() {
if (isDelivering.compareAndSet(false, true)) {
try {
Optional<StoringSubscriber.Event<T>> next = storingSubscriber.peek();
while (numRequestsInFlight.get() < maxConcurrentExecutions) {
if (!next.isPresent()) {
subscription.request(1);
break;
}

switch (next.get().type()) {
case ON_COMPLETE:
handleCompleteEvent();
break;
case ON_ERROR:
handleError(next.get().runtimeError());
break;
case ON_NEXT:
handleOnNext(next.get().value());
break;
default:
handleError(new IllegalStateException("Unknown stored type: " + next.get().type()));
break;
}

next = storingSubscriber.peek();
}
} finally {
isDelivering.set(false);
}
}
}

private void handleOnNext(T item) {
storingSubscriber.poll();

int numberOfRequestInFlight = numRequestsInFlight.incrementAndGet();
log.debug(() -> "Delivering next item, numRequestInFlight=" + numberOfRequestInFlight);

numRequestsInFlight.incrementAndGet();
consumer.apply(item).whenComplete((r, t) -> {
numRequestsInFlight.decrementAndGet();
if (!isStreamingDone) {
subscription.request(1);
} else {
flushBufferIfNeeded();
}
checkForCompletion();
});
}

private void handleCompleteEvent() {
if (numRequestsInFlight.get() == 0) {
returnFuture.complete(null);
storingSubscriber.poll();
}
}

@Override
public void onError(Throwable t) {
handleError(t);
storingSubscriber.onError(t);
}

private void handleError(Throwable t) {
upstreamDone = true;
returnFuture.completeExceptionally(t);
storingSubscriber.poll();
}

@Override
public void onComplete() {
isStreamingDone = true;
storingSubscriber.onComplete();
flushBufferIfNeeded();
upstreamDone = true;
checkForCompletion();
}

private void checkForCompletion() {
if (!upstreamDone) {
subscription.request(1);
return;
}

if (upstreamDone && numRequestsInFlight.get() == 0) {
returnFuture.complete(null);
}
}

/**
Expand Down