Skip to content

Add CountDownActionListener #92308

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 10 commits into from
Dec 19, 2022
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0 and the Server Side Public License, v 1; you may not use this file except
* in compliance with, at your election, the Elastic License 2.0 or the Server
* Side Public License, v 1.
*/
package org.elasticsearch.action.support;

import org.elasticsearch.action.ActionListener;

import java.util.Objects;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicReference;

/**
* Wraps another listener and adds a counter -- each invocation of this listener will decrement the counter, and when the counter has been
* exhausted the final invocation of this listener will delegate to the wrapped listener. Similar to {@link GroupedActionListener}, but for
* the cases where tracking individual results is not useful.
*/
public final class CountDownActionListener extends ActionListener.Delegating<Void, Void> {

private final AtomicInteger countDown;
private final AtomicReference<Exception> failure = new AtomicReference<>();

/**
* Creates a new listener
* @param groupSize the group size
* @param delegate the delegate listener
*/
public CountDownActionListener(int groupSize, ActionListener<Void> delegate) {
super(Objects.requireNonNull(delegate));
if (groupSize <= 0) {
assert false : "illegal group size [" + groupSize + "]";
throw new IllegalArgumentException("groupSize must be greater than 0 but was " + groupSize);
}
countDown = new AtomicInteger(groupSize);
}

/**
* Creates a new listener
* @param groupSize the group size
* @param runnable the runnable
*/
public CountDownActionListener(int groupSize, Runnable runnable) {
this(groupSize, ActionListener.wrap(Objects.requireNonNull(runnable)));
}

private boolean countDown() {
final var result = countDown.getAndUpdate(current -> Math.max(0, current - 1));
assert result > 0;
return result == 1;
}

@Override
public void onResponse(Void element) {
if (countDown()) {
if (failure.get() != null) {
super.onFailure(failure.get());
} else {
delegate.onResponse(element);
}
}
}

@Override
public void onFailure(Exception e) {
if (failure.compareAndSet(null, e) == false) {
failure.accumulateAndGet(e, (current, update) -> {
// we have to avoid self-suppression!
if (update != current) {
current.addSuppressed(update);
}
return current;
});
}
if (countDown()) {
super.onFailure(failure.get());
}
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.action.support.GroupedActionListener;
import org.elasticsearch.action.support.CountDownActionListener;
import org.elasticsearch.cluster.coordination.FollowersChecker;
import org.elasticsearch.cluster.node.DiscoveryNode;
import org.elasticsearch.cluster.node.DiscoveryNodes;
Expand Down Expand Up @@ -97,10 +97,7 @@ public void connectToNodes(DiscoveryNodes discoveryNodes, Runnable onCompletion)
return;
}

final GroupedActionListener<Void> listener = new GroupedActionListener<>(
discoveryNodes.getSize(),
ActionListener.wrap(onCompletion)
);
final CountDownActionListener listener = new CountDownActionListener(discoveryNodes.getSize(), onCompletion);

final List<Runnable> runnables = new ArrayList<>(discoveryNodes.getSize());
synchronized (mutex) {
Expand Down Expand Up @@ -159,10 +156,7 @@ void ensureConnections(Runnable onCompletion) {
runnables.add(onCompletion);
} else {
logger.trace("ensureConnections: {}", targetsByNode);
final GroupedActionListener<Void> listener = new GroupedActionListener<>(
connectionTargets.size(),
ActionListener.wrap(onCompletion)
);
final CountDownActionListener listener = new CountDownActionListener(connectionTargets.size(), onCompletion);
for (final ConnectionTarget connectionTarget : connectionTargets) {
runnables.add(connectionTarget.connect(listener));
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.action.support.GroupedActionListener;
import org.elasticsearch.action.support.CountDownActionListener;
import org.elasticsearch.client.internal.Client;
import org.elasticsearch.cluster.ClusterInfo;
import org.elasticsearch.cluster.ClusterState;
Expand Down Expand Up @@ -301,7 +301,7 @@ public void onNewInfo(ClusterInfo info) {
}
}

final ActionListener<Void> listener = new GroupedActionListener<>(3, ActionListener.wrap(this::checkFinished));
final ActionListener<Void> listener = new CountDownActionListener(3, this::checkFinished);

if (reroute) {
logger.debug("rerouting shards: [{}]", explanation);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
import org.elasticsearch.action.ActionRunnable;
import org.elasticsearch.action.SingleResultDeduplicator;
import org.elasticsearch.action.StepListener;
import org.elasticsearch.action.support.CountDownActionListener;
import org.elasticsearch.action.support.GroupedActionListener;
import org.elasticsearch.action.support.ListenableActionFuture;
import org.elasticsearch.action.support.PlainActionFuture;
Expand Down Expand Up @@ -964,7 +965,7 @@ private void doDeleteShardSnapshots(
writeUpdatedRepoDataStep.whenComplete(updatedRepoData -> {
listener.onRepositoryDataWritten(updatedRepoData);
// Run unreferenced blobs cleanup in parallel to shard-level snapshot deletion
final ActionListener<Void> afterCleanupsListener = new GroupedActionListener<>(2, ActionListener.wrap(listener::onDone));
final ActionListener<Void> afterCleanupsListener = new CountDownActionListener(2, listener::onDone);
cleanupUnlinkedRootAndIndicesBlobs(snapshotIds, foundIndices, rootBlobs, updatedRepoData, afterCleanupsListener);
asyncCleanupUnlinkedShardLevelBlobs(
repositoryData,
Expand All @@ -978,10 +979,10 @@ private void doDeleteShardSnapshots(
final RepositoryData updatedRepoData = repositoryData.removeSnapshots(snapshotIds, ShardGenerations.EMPTY);
writeIndexGen(updatedRepoData, repositoryStateId, repoMetaVersion, Function.identity(), ActionListener.wrap(newRepoData -> {
// Run unreferenced blobs cleanup in parallel to shard-level snapshot deletion
final ActionListener<Void> afterCleanupsListener = new GroupedActionListener<>(2, ActionListener.wrap(() -> {
final ActionListener<Void> afterCleanupsListener = new CountDownActionListener(2, () -> {
listener.onRepositoryDataWritten(newRepoData);
listener.onDone();
}));
});
cleanupUnlinkedRootAndIndicesBlobs(snapshotIds, foundIndices, rootBlobs, newRepoData, afterCleanupsListener);
final StepListener<Collection<ShardSnapshotMetaDeleteResult>> writeMetaAndComputeDeletesStep = new StepListener<>();
writeUpdatedShardMetaDataAndComputeDeletes(snapshotIds, repositoryData, false, writeMetaAndComputeDeletesStep);
Expand Down Expand Up @@ -1414,7 +1415,7 @@ public void finalizeSnapshot(final FinalizeSnapshotContext finalizeSnapshotConte
indexMetaIdentifiers = null;
}

final ActionListener<Void> allMetaListener = new GroupedActionListener<>(2 + indices.size(), ActionListener.wrap(v -> {
final ActionListener<Void> allMetaListener = new CountDownActionListener(2 + indices.size(), ActionListener.wrap(v -> {
final String slmPolicy = slmPolicy(snapshotInfo);
final SnapshotDetails snapshotDetails = new SnapshotDetails(
snapshotInfo.state(),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.action.StepListener;
import org.elasticsearch.action.admin.cluster.snapshots.restore.RestoreSnapshotRequest;
import org.elasticsearch.action.support.GroupedActionListener;
import org.elasticsearch.action.support.CountDownActionListener;
import org.elasticsearch.action.support.IndicesOptions;
import org.elasticsearch.cluster.ClusterChangedEvent;
import org.elasticsearch.cluster.ClusterState;
Expand Down Expand Up @@ -525,11 +525,11 @@ static void refreshRepositoryUuids(boolean enabled, RepositoriesService reposito
"refreshing repository UUIDs for repositories [{}]",
repositories.stream().map(repository -> repository.getMetadata().name()).collect(Collectors.joining(","))
);
final ActionListener<RepositoryData> groupListener = new GroupedActionListener<>(
final ActionListener<RepositoryData> countDownListener = new CountDownActionListener(
repositories.size(),
new ActionListener<Collection<Void>>() {
new ActionListener<Void>() {
@Override
public void onResponse(Collection<Void> ignored) {
public void onResponse(Void ignored) {
logger.debug("repository UUID refresh completed");
refreshListener.onResponse(null);
}
Expand All @@ -543,7 +543,7 @@ public void onFailure(Exception e) {
).map(repositoryData -> null /* don't collect the RepositoryData */);

for (Repository repository : repositories) {
repository.getRepositoryData(groupListener);
repository.getRepositoryData(countDownListener);
}

}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import org.elasticsearch.action.ResultDeduplicator;
import org.elasticsearch.action.StepListener;
import org.elasticsearch.action.support.ChannelActionListener;
import org.elasticsearch.action.support.CountDownActionListener;
import org.elasticsearch.action.support.GroupedActionListener;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
Expand Down Expand Up @@ -99,18 +100,18 @@ void doCancelTaskAndDescendants(CancellableTask task, String reason, boolean wai
if (task.shouldCancelChildrenOnCancellation()) {
logger.trace("cancelling task [{}] and its descendants", taskId);
StepListener<Void> completedListener = new StepListener<>();
GroupedActionListener<Void> groupedListener = new GroupedActionListener<>(3, completedListener.map(r -> null));
CountDownActionListener countDownListener = new CountDownActionListener(3, completedListener);
Collection<Transport.Connection> childConnections = taskManager.startBanOnChildTasks(task.getId(), reason, () -> {
logger.trace("child tasks of parent [{}] are completed", taskId);
groupedListener.onResponse(null);
countDownListener.onResponse(null);
});
taskManager.cancel(task, reason, () -> {
logger.trace("task [{}] is cancelled", taskId);
groupedListener.onResponse(null);
countDownListener.onResponse(null);
});
StepListener<Void> setBanListener = new StepListener<>();
setBanOnChildConnections(reason, waitForCompletion, task, childConnections, setBanListener);
setBanListener.addListener(groupedListener);
setBanListener.addListener(countDownListener);
// If we start unbanning when the last child task completed and that child task executed with a specific user, then unban
// requests are denied because internal requests can't run with a user. We need to remove bans with the current thread context.
final Runnable removeBansRunnable = transportService.getThreadPool()
Expand Down Expand Up @@ -149,7 +150,7 @@ private void setBanOnChildConnections(
}
final TaskId taskId = new TaskId(localNodeId(), task.getId());
logger.trace("cancelling child tasks of [{}] on child connections {}", taskId, childConnections);
GroupedActionListener<Void> groupedListener = new GroupedActionListener<>(childConnections.size(), listener.map(r -> null));
CountDownActionListener countDownListener = new CountDownActionListener(childConnections.size(), listener);
final BanParentTaskRequest banRequest = BanParentTaskRequest.createSetBanParentTaskRequest(taskId, reason, waitForCompletion);
for (Transport.Connection connection : childConnections) {
assert TransportService.unwrapConnection(connection) == connection : "Child connection must be unwrapped";
Expand All @@ -162,7 +163,7 @@ private void setBanOnChildConnections(
@Override
public void handleResponse(TransportResponse.Empty response) {
logger.trace("sent ban for tasks with the parent [{}] for connection [{}]", taskId, connection);
groupedListener.onResponse(null);
countDownListener.onResponse(null);
}

@Override
Expand All @@ -188,7 +189,7 @@ public void handleException(TransportException exp) {
);
}

groupedListener.onFailure(exp);
countDownListener.onFailure(exp);
}
}
);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
import org.apache.logging.log4j.Logger;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.action.OriginalIndices;
import org.elasticsearch.action.support.GroupedActionListener;
import org.elasticsearch.action.support.CountDownActionListener;
import org.elasticsearch.action.support.IndicesOptions;
import org.elasticsearch.action.support.PlainActionFuture;
import org.elasticsearch.client.internal.Client;
Expand Down Expand Up @@ -332,14 +332,14 @@ synchronized void updateRemoteCluster(String clusterAlias, Settings newSettings,
*/
void initializeRemoteClusters() {
final TimeValue timeValue = REMOTE_INITIAL_CONNECTION_TIMEOUT_SETTING.get(settings);
final PlainActionFuture<Collection<Void>> future = new PlainActionFuture<>();
final PlainActionFuture<Void> future = new PlainActionFuture<>();
Set<String> enabledClusters = RemoteClusterAware.getEnabledRemoteClusters(settings);

if (enabledClusters.isEmpty()) {
return;
}

GroupedActionListener<Void> listener = new GroupedActionListener<>(enabledClusters.size(), future);
CountDownActionListener listener = new CountDownActionListener(enabledClusters.size(), future);
for (String clusterAlias : enabledClusters) {
updateRemoteCluster(clusterAlias, settings, listener);
}
Expand Down
Loading