From db37bf0bbca795f2b048ddf0e0a4ba6bef2d8073 Mon Sep 17 00:00:00 2001 From: Joe Gallo Date: Tue, 6 Dec 2022 12:09:45 -0500 Subject: [PATCH 1/7] Add a CountDownActionListener It's very much like GroupedActionListener, but for when you don't care about the individual results at all. --- .../support/CountDownActionListener.java | 67 +++++++++ .../support/CountDownActionListenerTests.java | 132 ++++++++++++++++++ 2 files changed, 199 insertions(+) create mode 100644 server/src/main/java/org/elasticsearch/action/support/CountDownActionListener.java create mode 100644 server/src/test/java/org/elasticsearch/action/support/CountDownActionListenerTests.java diff --git a/server/src/main/java/org/elasticsearch/action/support/CountDownActionListener.java b/server/src/main/java/org/elasticsearch/action/support/CountDownActionListener.java new file mode 100644 index 0000000000000..72d65c2e53980 --- /dev/null +++ b/server/src/main/java/org/elasticsearch/action/support/CountDownActionListener.java @@ -0,0 +1,67 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0 and the Server Side Public License, v 1; you may not use this file except + * in compliance with, at your election, the Elastic License 2.0 or the Server + * Side Public License, v 1. + */ +package org.elasticsearch.action.support; + +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.common.util.concurrent.CountDown; + +import java.util.Objects; +import java.util.concurrent.atomic.AtomicReference; + +/** + * Wraps another listener and adds a counter -- each invocation of this listener will decrement the counter, and when the counter has been + * exhausted the final invocation of this listener will delegate to the wrapped listener. Similar to {@link GroupedActionListener}, but for + * the cases where tracking individual results is not useful. 
+ */ +public final class CountDownActionListener extends ActionListener.Delegating { + + private final CountDown countDown; + private final AtomicReference failure = new AtomicReference<>(); + + /** + * Creates a new listener + * @param groupSize the group size + * @param delegate the delegate listener + */ + public CountDownActionListener(int groupSize, ActionListener delegate) { + super(Objects.requireNonNull(delegate)); + if (groupSize <= 0) { + assert false : "illegal group size [" + groupSize + "]"; + throw new IllegalArgumentException("groupSize must be greater than 0 but was " + groupSize); + } + countDown = new CountDown(groupSize); + } + + @Override + public void onResponse(Void element) { + if (countDown.countDown()) { + if (failure.get() != null) { + super.onFailure(failure.get()); + } else { + delegate.onResponse(element); + } + } + } + + @Override + public void onFailure(Exception e) { + if (failure.compareAndSet(null, e) == false) { + failure.accumulateAndGet(e, (current, update) -> { + // we have to avoid self-suppression! + if (update != current) { + current.addSuppressed(update); + } + return current; + }); + } + if (countDown.countDown()) { + super.onFailure(failure.get()); + } + } + +} diff --git a/server/src/test/java/org/elasticsearch/action/support/CountDownActionListenerTests.java b/server/src/test/java/org/elasticsearch/action/support/CountDownActionListenerTests.java new file mode 100644 index 0000000000000..e1a8d6c1447b8 --- /dev/null +++ b/server/src/test/java/org/elasticsearch/action/support/CountDownActionListenerTests.java @@ -0,0 +1,132 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0 and the Server Side Public License, v 1; you may not use this file except + * in compliance with, at your election, the Elastic License 2.0 or the Server + * Side Public License, v 1. 
+ */ +package org.elasticsearch.action.support; + +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.test.ESTestCase; + +import java.io.IOException; +import java.util.concurrent.CyclicBarrier; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.concurrent.atomic.AtomicReference; + +import static org.hamcrest.CoreMatchers.instanceOf; +import static org.hamcrest.CoreMatchers.nullValue; +import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.not; + +public class CountDownActionListenerTests extends ESTestCase { + + public void testNotifications() throws InterruptedException { + AtomicBoolean called = new AtomicBoolean(false); + ActionListener result = new ActionListener<>() { + @Override + public void onResponse(Void ignored) { + called.set(true); + } + + @Override + public void onFailure(Exception e) { + throw new AssertionError(e); + } + }; + final int groupSize = randomIntBetween(10, 1000); + AtomicInteger count = new AtomicInteger(); + CountDownActionListener listener = new CountDownActionListener(groupSize, result); + int numThreads = randomIntBetween(2, 5); + Thread[] threads = new Thread[numThreads]; + CyclicBarrier barrier = new CyclicBarrier(numThreads); + for (int i = 0; i < numThreads; i++) { + threads[i] = new Thread(() -> { + try { + barrier.await(10, TimeUnit.SECONDS); + } catch (Exception e) { + throw new AssertionError(e); + } + while (count.incrementAndGet() <= groupSize) { + listener.onResponse(null); + } + }); + threads[i].start(); + } + for (Thread t : threads) { + t.join(); + } + assertTrue(called.get()); + } + + public void testFailed() { + AtomicBoolean called = new AtomicBoolean(false); + AtomicReference excRef = new AtomicReference<>(); + + ActionListener result = new ActionListener<>() { + @Override + 
public void onResponse(Void ignored) { + called.set(true); + } + + @Override + public void onFailure(Exception e) { + excRef.set(e); + } + }; + int size = randomIntBetween(3, 4); + CountDownActionListener listener = new CountDownActionListener(size, result); + listener.onResponse(null); + IOException ioException = new IOException(); + RuntimeException rtException = new RuntimeException(); + listener.onFailure(rtException); + listener.onFailure(ioException); + if (size == 4) { + listener.onResponse(null); + } + assertNotNull(excRef.get()); + assertEquals(rtException, excRef.get()); + assertEquals(1, excRef.get().getSuppressed().length); + assertEquals(ioException, excRef.get().getSuppressed()[0]); + assertFalse(called.get()); + } + + public void testConcurrentFailures() throws InterruptedException { + AtomicReference finalException = new AtomicReference<>(); + int numGroups = randomIntBetween(10, 100); + CountDownActionListener listener = new CountDownActionListener(numGroups, ActionListener.wrap(r -> {}, finalException::set)); + ExecutorService executorService = Executors.newFixedThreadPool(numGroups); + for (int i = 0; i < numGroups; i++) { + executorService.submit(() -> listener.onFailure(new IOException())); + } + + executorService.shutdown(); + executorService.awaitTermination(10, TimeUnit.SECONDS); + + Exception exception = finalException.get(); + assertNotNull(exception); + assertThat(exception, instanceOf(IOException.class)); + assertEquals(numGroups - 1, exception.getSuppressed().length); + } + + /* + * It can happen that the same exception causes a grouped listener to be notified of the failure multiple times. Since we suppress + * additional exceptions into the first exception, we have to guard against suppressing into the same exception, which could occur if we + * are notified of the same failure multiple times. This test verifies that the guard against self-suppression remains. 
+ */ + public void testRepeatNotificationForTheSameException() { + final AtomicReference finalException = new AtomicReference<>(); + final CountDownActionListener listener = new CountDownActionListener(2, ActionListener.wrap(r -> {}, finalException::set)); + final Exception e = new Exception(); + // repeat notification for the same exception + listener.onFailure(e); + listener.onFailure(e); + assertThat(finalException.get(), not(nullValue())); + assertThat(finalException.get(), equalTo(e)); + } +} From 497912143b5f58199082097b04b473339df84a4d Mon Sep 17 00:00:00 2001 From: Joe Gallo Date: Tue, 6 Dec 2022 12:12:15 -0500 Subject: [PATCH 2/7] Use CountDownActionListener where feasible --- .../cluster/NodeConnectionsService.java | 9 +++------ .../routing/allocation/DiskThresholdMonitor.java | 4 ++-- .../blobstore/BlobStoreRepository.java | 7 ++++--- .../elasticsearch/snapshots/RestoreService.java | 10 +++++----- .../tasks/TaskCancellationService.java | 15 ++++++++------- .../transport/RemoteClusterService.java | 6 +++--- .../blobstore/ShardSnapshotTaskRunnerTests.java | 4 ++-- .../xpack/slm/SnapshotRetentionTask.java | 8 ++++---- .../store/SearchableSnapshotDirectory.java | 14 +++++++------- 9 files changed, 38 insertions(+), 39 deletions(-) diff --git a/server/src/main/java/org/elasticsearch/cluster/NodeConnectionsService.java b/server/src/main/java/org/elasticsearch/cluster/NodeConnectionsService.java index 077771dbcc761..e203ff72e84cd 100644 --- a/server/src/main/java/org/elasticsearch/cluster/NodeConnectionsService.java +++ b/server/src/main/java/org/elasticsearch/cluster/NodeConnectionsService.java @@ -11,7 +11,7 @@ import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; import org.elasticsearch.action.ActionListener; -import org.elasticsearch.action.support.GroupedActionListener; +import org.elasticsearch.action.support.CountDownActionListener; import org.elasticsearch.cluster.coordination.FollowersChecker; import 
org.elasticsearch.cluster.node.DiscoveryNode; import org.elasticsearch.cluster.node.DiscoveryNodes; @@ -97,10 +97,7 @@ public void connectToNodes(DiscoveryNodes discoveryNodes, Runnable onCompletion) return; } - final GroupedActionListener listener = new GroupedActionListener<>( - discoveryNodes.getSize(), - ActionListener.wrap(onCompletion) - ); + final CountDownActionListener listener = new CountDownActionListener(discoveryNodes.getSize(), ActionListener.wrap(onCompletion)); final List runnables = new ArrayList<>(discoveryNodes.getSize()); synchronized (mutex) { @@ -159,7 +156,7 @@ void ensureConnections(Runnable onCompletion) { runnables.add(onCompletion); } else { logger.trace("ensureConnections: {}", targetsByNode); - final GroupedActionListener listener = new GroupedActionListener<>( + final CountDownActionListener listener = new CountDownActionListener( connectionTargets.size(), ActionListener.wrap(onCompletion) ); diff --git a/server/src/main/java/org/elasticsearch/cluster/routing/allocation/DiskThresholdMonitor.java b/server/src/main/java/org/elasticsearch/cluster/routing/allocation/DiskThresholdMonitor.java index 60598d62df015..21633375305dd 100644 --- a/server/src/main/java/org/elasticsearch/cluster/routing/allocation/DiskThresholdMonitor.java +++ b/server/src/main/java/org/elasticsearch/cluster/routing/allocation/DiskThresholdMonitor.java @@ -11,7 +11,7 @@ import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; import org.elasticsearch.action.ActionListener; -import org.elasticsearch.action.support.GroupedActionListener; +import org.elasticsearch.action.support.CountDownActionListener; import org.elasticsearch.client.internal.Client; import org.elasticsearch.cluster.ClusterInfo; import org.elasticsearch.cluster.ClusterState; @@ -301,7 +301,7 @@ public void onNewInfo(ClusterInfo info) { } } - final ActionListener listener = new GroupedActionListener<>(3, ActionListener.wrap(this::checkFinished)); + final ActionListener 
listener = new CountDownActionListener(3, ActionListener.wrap(this::checkFinished)); if (reroute) { logger.debug("rerouting shards: [{}]", explanation); diff --git a/server/src/main/java/org/elasticsearch/repositories/blobstore/BlobStoreRepository.java b/server/src/main/java/org/elasticsearch/repositories/blobstore/BlobStoreRepository.java index 3c536328522cf..597ff2469b748 100644 --- a/server/src/main/java/org/elasticsearch/repositories/blobstore/BlobStoreRepository.java +++ b/server/src/main/java/org/elasticsearch/repositories/blobstore/BlobStoreRepository.java @@ -27,6 +27,7 @@ import org.elasticsearch.action.ActionRunnable; import org.elasticsearch.action.SingleResultDeduplicator; import org.elasticsearch.action.StepListener; +import org.elasticsearch.action.support.CountDownActionListener; import org.elasticsearch.action.support.GroupedActionListener; import org.elasticsearch.action.support.ListenableActionFuture; import org.elasticsearch.action.support.PlainActionFuture; @@ -964,7 +965,7 @@ private void doDeleteShardSnapshots( writeUpdatedRepoDataStep.whenComplete(updatedRepoData -> { listener.onRepositoryDataWritten(updatedRepoData); // Run unreferenced blobs cleanup in parallel to shard-level snapshot deletion - final ActionListener afterCleanupsListener = new GroupedActionListener<>(2, ActionListener.wrap(listener::onDone)); + final ActionListener afterCleanupsListener = new CountDownActionListener(2, ActionListener.wrap(listener::onDone)); cleanupUnlinkedRootAndIndicesBlobs(snapshotIds, foundIndices, rootBlobs, updatedRepoData, afterCleanupsListener); asyncCleanupUnlinkedShardLevelBlobs( repositoryData, @@ -978,7 +979,7 @@ private void doDeleteShardSnapshots( final RepositoryData updatedRepoData = repositoryData.removeSnapshots(snapshotIds, ShardGenerations.EMPTY); writeIndexGen(updatedRepoData, repositoryStateId, repoMetaVersion, Function.identity(), ActionListener.wrap(newRepoData -> { // Run unreferenced blobs cleanup in parallel to shard-level 
snapshot deletion - final ActionListener afterCleanupsListener = new GroupedActionListener<>(2, ActionListener.wrap(() -> { + final ActionListener afterCleanupsListener = new CountDownActionListener(2, ActionListener.wrap(() -> { listener.onRepositoryDataWritten(newRepoData); listener.onDone(); })); @@ -1414,7 +1415,7 @@ public void finalizeSnapshot(final FinalizeSnapshotContext finalizeSnapshotConte indexMetaIdentifiers = null; } - final ActionListener allMetaListener = new GroupedActionListener<>(2 + indices.size(), ActionListener.wrap(v -> { + final ActionListener allMetaListener = new CountDownActionListener(2 + indices.size(), ActionListener.wrap(v -> { final String slmPolicy = slmPolicy(snapshotInfo); final SnapshotDetails snapshotDetails = new SnapshotDetails( snapshotInfo.state(), diff --git a/server/src/main/java/org/elasticsearch/snapshots/RestoreService.java b/server/src/main/java/org/elasticsearch/snapshots/RestoreService.java index 408163d4acaed..10ecef65c5e6f 100644 --- a/server/src/main/java/org/elasticsearch/snapshots/RestoreService.java +++ b/server/src/main/java/org/elasticsearch/snapshots/RestoreService.java @@ -14,7 +14,7 @@ import org.elasticsearch.action.ActionListener; import org.elasticsearch.action.StepListener; import org.elasticsearch.action.admin.cluster.snapshots.restore.RestoreSnapshotRequest; -import org.elasticsearch.action.support.GroupedActionListener; +import org.elasticsearch.action.support.CountDownActionListener; import org.elasticsearch.action.support.IndicesOptions; import org.elasticsearch.cluster.ClusterChangedEvent; import org.elasticsearch.cluster.ClusterState; @@ -525,11 +525,11 @@ static void refreshRepositoryUuids(boolean enabled, RepositoriesService reposito "refreshing repository UUIDs for repositories [{}]", repositories.stream().map(repository -> repository.getMetadata().name()).collect(Collectors.joining(",")) ); - final ActionListener groupListener = new GroupedActionListener<>( + final ActionListener 
countDownListener = new CountDownActionListener( repositories.size(), - new ActionListener>() { + new ActionListener() { @Override - public void onResponse(Collection ignored) { + public void onResponse(Void ignored) { logger.debug("repository UUID refresh completed"); refreshListener.onResponse(null); } @@ -543,7 +543,7 @@ public void onFailure(Exception e) { ).map(repositoryData -> null /* don't collect the RepositoryData */); for (Repository repository : repositories) { - repository.getRepositoryData(groupListener); + repository.getRepositoryData(countDownListener); } } diff --git a/server/src/main/java/org/elasticsearch/tasks/TaskCancellationService.java b/server/src/main/java/org/elasticsearch/tasks/TaskCancellationService.java index 788ae17f2bfbe..96e02bfa4f50f 100644 --- a/server/src/main/java/org/elasticsearch/tasks/TaskCancellationService.java +++ b/server/src/main/java/org/elasticsearch/tasks/TaskCancellationService.java @@ -17,6 +17,7 @@ import org.elasticsearch.action.ResultDeduplicator; import org.elasticsearch.action.StepListener; import org.elasticsearch.action.support.ChannelActionListener; +import org.elasticsearch.action.support.CountDownActionListener; import org.elasticsearch.action.support.GroupedActionListener; import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; @@ -99,18 +100,18 @@ void doCancelTaskAndDescendants(CancellableTask task, String reason, boolean wai if (task.shouldCancelChildrenOnCancellation()) { logger.trace("cancelling task [{}] and its descendants", taskId); StepListener completedListener = new StepListener<>(); - GroupedActionListener groupedListener = new GroupedActionListener<>(3, completedListener.map(r -> null)); + CountDownActionListener countDownListener = new CountDownActionListener(3, completedListener); Collection childConnections = taskManager.startBanOnChildTasks(task.getId(), reason, () -> { logger.trace("child tasks of parent [{}] are completed", taskId); 
- groupedListener.onResponse(null); + countDownListener.onResponse(null); }); taskManager.cancel(task, reason, () -> { logger.trace("task [{}] is cancelled", taskId); - groupedListener.onResponse(null); + countDownListener.onResponse(null); }); StepListener setBanListener = new StepListener<>(); setBanOnChildConnections(reason, waitForCompletion, task, childConnections, setBanListener); - setBanListener.addListener(groupedListener); + setBanListener.addListener(countDownListener); // If we start unbanning when the last child task completed and that child task executed with a specific user, then unban // requests are denied because internal requests can't run with a user. We need to remove bans with the current thread context. final Runnable removeBansRunnable = transportService.getThreadPool() @@ -149,7 +150,7 @@ private void setBanOnChildConnections( } final TaskId taskId = new TaskId(localNodeId(), task.getId()); logger.trace("cancelling child tasks of [{}] on child connections {}", taskId, childConnections); - GroupedActionListener groupedListener = new GroupedActionListener<>(childConnections.size(), listener.map(r -> null)); + CountDownActionListener countDownListener = new CountDownActionListener(childConnections.size(), listener); final BanParentTaskRequest banRequest = BanParentTaskRequest.createSetBanParentTaskRequest(taskId, reason, waitForCompletion); for (Transport.Connection connection : childConnections) { assert TransportService.unwrapConnection(connection) == connection : "Child connection must be unwrapped"; @@ -162,7 +163,7 @@ private void setBanOnChildConnections( @Override public void handleResponse(TransportResponse.Empty response) { logger.trace("sent ban for tasks with the parent [{}] for connection [{}]", taskId, connection); - groupedListener.onResponse(null); + countDownListener.onResponse(null); } @Override @@ -188,7 +189,7 @@ public void handleException(TransportException exp) { ); } - groupedListener.onFailure(exp); + 
countDownListener.onFailure(exp); } } ); diff --git a/server/src/main/java/org/elasticsearch/transport/RemoteClusterService.java b/server/src/main/java/org/elasticsearch/transport/RemoteClusterService.java index 696028f9dc5bd..bf380df5bbf77 100644 --- a/server/src/main/java/org/elasticsearch/transport/RemoteClusterService.java +++ b/server/src/main/java/org/elasticsearch/transport/RemoteClusterService.java @@ -12,7 +12,7 @@ import org.apache.logging.log4j.Logger; import org.elasticsearch.action.ActionListener; import org.elasticsearch.action.OriginalIndices; -import org.elasticsearch.action.support.GroupedActionListener; +import org.elasticsearch.action.support.CountDownActionListener; import org.elasticsearch.action.support.IndicesOptions; import org.elasticsearch.action.support.PlainActionFuture; import org.elasticsearch.client.internal.Client; @@ -332,14 +332,14 @@ synchronized void updateRemoteCluster(String clusterAlias, Settings newSettings, */ void initializeRemoteClusters() { final TimeValue timeValue = REMOTE_INITIAL_CONNECTION_TIMEOUT_SETTING.get(settings); - final PlainActionFuture> future = new PlainActionFuture<>(); + final PlainActionFuture future = new PlainActionFuture<>(); Set enabledClusters = RemoteClusterAware.getEnabledRemoteClusters(settings); if (enabledClusters.isEmpty()) { return; } - GroupedActionListener listener = new GroupedActionListener<>(enabledClusters.size(), future); + CountDownActionListener listener = new CountDownActionListener(enabledClusters.size(), future); for (String clusterAlias : enabledClusters) { updateRemoteCluster(clusterAlias, settings, listener); } diff --git a/server/src/test/java/org/elasticsearch/repositories/blobstore/ShardSnapshotTaskRunnerTests.java b/server/src/test/java/org/elasticsearch/repositories/blobstore/ShardSnapshotTaskRunnerTests.java index 6cecbd5403010..93b5ba26e1eca 100644 --- a/server/src/test/java/org/elasticsearch/repositories/blobstore/ShardSnapshotTaskRunnerTests.java +++ 
b/server/src/test/java/org/elasticsearch/repositories/blobstore/ShardSnapshotTaskRunnerTests.java @@ -11,7 +11,7 @@ import org.apache.lucene.store.ByteBuffersDirectory; import org.elasticsearch.Version; import org.elasticsearch.action.ActionListener; -import org.elasticsearch.action.support.GroupedActionListener; +import org.elasticsearch.action.support.CountDownActionListener; import org.elasticsearch.cluster.metadata.IndexMetadata; import org.elasticsearch.common.UUIDs; import org.elasticsearch.common.settings.Settings; @@ -72,7 +72,7 @@ public void snapshotShard(SnapshotShardContext context) { finishedShardSnapshots.incrementAndGet(); } else { expectedFileSnapshotTasks.addAndGet(filesToUpload); - ActionListener uploadListener = new GroupedActionListener<>( + ActionListener uploadListener = new CountDownActionListener( filesToUpload, ActionListener.wrap(finishedShardSnapshots::incrementAndGet) ); diff --git a/x-pack/plugin/ilm/src/main/java/org/elasticsearch/xpack/slm/SnapshotRetentionTask.java b/x-pack/plugin/ilm/src/main/java/org/elasticsearch/xpack/slm/SnapshotRetentionTask.java index ad5eacb07b3e6..b1f8c11ba9e72 100644 --- a/x-pack/plugin/ilm/src/main/java/org/elasticsearch/xpack/slm/SnapshotRetentionTask.java +++ b/x-pack/plugin/ilm/src/main/java/org/elasticsearch/xpack/slm/SnapshotRetentionTask.java @@ -10,7 +10,7 @@ import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; import org.elasticsearch.action.ActionListener; -import org.elasticsearch.action.support.GroupedActionListener; +import org.elasticsearch.action.support.CountDownActionListener; import org.elasticsearch.action.support.master.AcknowledgedResponse; import org.elasticsearch.client.internal.Client; import org.elasticsearch.client.internal.OriginSettingClient; @@ -328,9 +328,9 @@ void deleteSnapshots( long startTime = nowNanoSupplier.getAsLong(); final AtomicInteger deleted = new AtomicInteger(0); final AtomicInteger failed = new AtomicInteger(0); - final 
GroupedActionListener allDeletesListener = new GroupedActionListener<>( + final CountDownActionListener allDeletesListener = new CountDownActionListener( snapshotsToDelete.size(), - ActionListener.runAfter(listener.map(v -> null), () -> { + ActionListener.runAfter(listener, () -> { TimeValue totalElapsedTime = TimeValue.timeValueNanos(nowNanoSupplier.getAsLong() - startTime); logger.debug("total elapsed time for deletion of [{}] snapshots: {}", deleted, totalElapsedTime); slmStats.deletionTime(totalElapsedTime); @@ -354,7 +354,7 @@ private void deleteSnapshots( ActionListener listener ) { - final ActionListener allDeletesListener = new GroupedActionListener<>(snapshots.size(), listener.map(v -> null)); + final ActionListener allDeletesListener = new CountDownActionListener(snapshots.size(), listener); for (Tuple info : snapshots) { final SnapshotId snapshotId = info.v1(); if (runningDeletions.add(snapshotId) == false) { diff --git a/x-pack/plugin/searchable-snapshots/src/main/java/org/elasticsearch/xpack/searchablesnapshots/store/SearchableSnapshotDirectory.java b/x-pack/plugin/searchable-snapshots/src/main/java/org/elasticsearch/xpack/searchablesnapshots/store/SearchableSnapshotDirectory.java index 3dfe1ed53df2a..c0c6fca08b369 100644 --- a/x-pack/plugin/searchable-snapshots/src/main/java/org/elasticsearch/xpack/searchablesnapshots/store/SearchableSnapshotDirectory.java +++ b/x-pack/plugin/searchable-snapshots/src/main/java/org/elasticsearch/xpack/searchablesnapshots/store/SearchableSnapshotDirectory.java @@ -20,7 +20,7 @@ import org.elasticsearch.action.ActionListener; import org.elasticsearch.action.ActionRunnable; import org.elasticsearch.action.StepListener; -import org.elasticsearch.action.support.GroupedActionListener; +import org.elasticsearch.action.support.CountDownActionListener; import org.elasticsearch.cluster.metadata.IndexMetadata; import org.elasticsearch.cluster.routing.RecoverySource; import org.elasticsearch.common.blobstore.BlobContainer; @@ 
-484,9 +484,9 @@ private void prewarmCache(ActionListener listener) { final BlockingQueue, CheckedRunnable>> queue = new LinkedBlockingQueue<>(); final Executor executor = prewarmExecutor(); - final GroupedActionListener completionListener = new GroupedActionListener<>( + final CountDownActionListener completionListener = new CountDownActionListener( snapshot().totalFileCount(), - ActionListener.wrap(voids -> { + ActionListener.wrap(ignored -> { recoveryState.setPreWarmComplete(); listener.onResponse(null); }, listener::onFailure) @@ -510,9 +510,9 @@ private void prewarmCache(ActionListener listener) { assert input instanceof CachedBlobContainerIndexInput : "expected cached index input but got " + input.getClass(); final int numberOfParts = file.numberOfParts(); - final StepListener> fileCompletionListener = new StepListener<>(); - fileCompletionListener.addListener(completionListener.map(voids -> null)); - fileCompletionListener.whenComplete(voids -> { + final StepListener fileCompletionListener = new StepListener<>(); + fileCompletionListener.addListener(completionListener); + fileCompletionListener.whenComplete(ignored -> { logger.debug("{} file [{}] prewarmed", shardId, file.physicalName()); input.close(); }, e -> { @@ -520,7 +520,7 @@ private void prewarmCache(ActionListener listener) { IOUtils.closeWhileHandlingException(input); }); - final GroupedActionListener partsListener = new GroupedActionListener<>(numberOfParts, fileCompletionListener); + final CountDownActionListener partsListener = new CountDownActionListener(numberOfParts, fileCompletionListener); submitted = true; for (int p = 0; p < numberOfParts; p++) { final int part = p; From f6a5dbf7e52f114fcc7e9b4daf5df782c88acc74 Mon Sep 17 00:00:00 2001 From: Joe Gallo Date: Tue, 6 Dec 2022 12:17:12 -0500 Subject: [PATCH 3/7] Add a convenience constructor that wraps a Runnable It's a common enough calling pattern that it seems handy to add special support for it. 
--- .../action/support/CountDownActionListener.java | 9 +++++++++ .../elasticsearch/cluster/NodeConnectionsService.java | 7 ++----- .../cluster/routing/allocation/DiskThresholdMonitor.java | 2 +- .../repositories/blobstore/BlobStoreRepository.java | 6 +++--- .../blobstore/ShardSnapshotTaskRunnerTests.java | 5 +---- 5 files changed, 16 insertions(+), 13 deletions(-) diff --git a/server/src/main/java/org/elasticsearch/action/support/CountDownActionListener.java b/server/src/main/java/org/elasticsearch/action/support/CountDownActionListener.java index 72d65c2e53980..31acf9f5a44c1 100644 --- a/server/src/main/java/org/elasticsearch/action/support/CountDownActionListener.java +++ b/server/src/main/java/org/elasticsearch/action/support/CountDownActionListener.java @@ -37,6 +37,15 @@ public CountDownActionListener(int groupSize, ActionListener delegate) { countDown = new CountDown(groupSize); } + /** + * Creates a new listener + * @param groupSize the group size + * @param runnable the runnable + */ + public CountDownActionListener(int groupSize, Runnable runnable) { + this(groupSize, ActionListener.wrap(Objects.requireNonNull(runnable))); + } + @Override public void onResponse(Void element) { if (countDown.countDown()) { diff --git a/server/src/main/java/org/elasticsearch/cluster/NodeConnectionsService.java b/server/src/main/java/org/elasticsearch/cluster/NodeConnectionsService.java index e203ff72e84cd..2e67288358c2b 100644 --- a/server/src/main/java/org/elasticsearch/cluster/NodeConnectionsService.java +++ b/server/src/main/java/org/elasticsearch/cluster/NodeConnectionsService.java @@ -97,7 +97,7 @@ public void connectToNodes(DiscoveryNodes discoveryNodes, Runnable onCompletion) return; } - final CountDownActionListener listener = new CountDownActionListener(discoveryNodes.getSize(), ActionListener.wrap(onCompletion)); + final CountDownActionListener listener = new CountDownActionListener(discoveryNodes.getSize(), onCompletion); final List runnables = new 
ArrayList<>(discoveryNodes.getSize()); synchronized (mutex) { @@ -156,10 +156,7 @@ void ensureConnections(Runnable onCompletion) { runnables.add(onCompletion); } else { logger.trace("ensureConnections: {}", targetsByNode); - final CountDownActionListener listener = new CountDownActionListener( - connectionTargets.size(), - ActionListener.wrap(onCompletion) - ); + final CountDownActionListener listener = new CountDownActionListener(connectionTargets.size(), onCompletion); for (final ConnectionTarget connectionTarget : connectionTargets) { runnables.add(connectionTarget.connect(listener)); } diff --git a/server/src/main/java/org/elasticsearch/cluster/routing/allocation/DiskThresholdMonitor.java b/server/src/main/java/org/elasticsearch/cluster/routing/allocation/DiskThresholdMonitor.java index 21633375305dd..82413fe3723ae 100644 --- a/server/src/main/java/org/elasticsearch/cluster/routing/allocation/DiskThresholdMonitor.java +++ b/server/src/main/java/org/elasticsearch/cluster/routing/allocation/DiskThresholdMonitor.java @@ -301,7 +301,7 @@ public void onNewInfo(ClusterInfo info) { } } - final ActionListener listener = new CountDownActionListener(3, ActionListener.wrap(this::checkFinished)); + final ActionListener listener = new CountDownActionListener(3, this::checkFinished); if (reroute) { logger.debug("rerouting shards: [{}]", explanation); diff --git a/server/src/main/java/org/elasticsearch/repositories/blobstore/BlobStoreRepository.java b/server/src/main/java/org/elasticsearch/repositories/blobstore/BlobStoreRepository.java index 597ff2469b748..079a21f341e2d 100644 --- a/server/src/main/java/org/elasticsearch/repositories/blobstore/BlobStoreRepository.java +++ b/server/src/main/java/org/elasticsearch/repositories/blobstore/BlobStoreRepository.java @@ -965,7 +965,7 @@ private void doDeleteShardSnapshots( writeUpdatedRepoDataStep.whenComplete(updatedRepoData -> { listener.onRepositoryDataWritten(updatedRepoData); // Run unreferenced blobs cleanup in parallel to 
shard-level snapshot deletion - final ActionListener afterCleanupsListener = new CountDownActionListener(2, ActionListener.wrap(listener::onDone)); + final ActionListener afterCleanupsListener = new CountDownActionListener(2, listener::onDone); cleanupUnlinkedRootAndIndicesBlobs(snapshotIds, foundIndices, rootBlobs, updatedRepoData, afterCleanupsListener); asyncCleanupUnlinkedShardLevelBlobs( repositoryData, @@ -979,10 +979,10 @@ private void doDeleteShardSnapshots( final RepositoryData updatedRepoData = repositoryData.removeSnapshots(snapshotIds, ShardGenerations.EMPTY); writeIndexGen(updatedRepoData, repositoryStateId, repoMetaVersion, Function.identity(), ActionListener.wrap(newRepoData -> { // Run unreferenced blobs cleanup in parallel to shard-level snapshot deletion - final ActionListener afterCleanupsListener = new CountDownActionListener(2, ActionListener.wrap(() -> { + final ActionListener afterCleanupsListener = new CountDownActionListener(2, () -> { listener.onRepositoryDataWritten(newRepoData); listener.onDone(); - })); + }); cleanupUnlinkedRootAndIndicesBlobs(snapshotIds, foundIndices, rootBlobs, newRepoData, afterCleanupsListener); final StepListener> writeMetaAndComputeDeletesStep = new StepListener<>(); writeUpdatedShardMetaDataAndComputeDeletes(snapshotIds, repositoryData, false, writeMetaAndComputeDeletesStep); diff --git a/server/src/test/java/org/elasticsearch/repositories/blobstore/ShardSnapshotTaskRunnerTests.java b/server/src/test/java/org/elasticsearch/repositories/blobstore/ShardSnapshotTaskRunnerTests.java index 93b5ba26e1eca..10038993f4c74 100644 --- a/server/src/test/java/org/elasticsearch/repositories/blobstore/ShardSnapshotTaskRunnerTests.java +++ b/server/src/test/java/org/elasticsearch/repositories/blobstore/ShardSnapshotTaskRunnerTests.java @@ -72,10 +72,7 @@ public void snapshotShard(SnapshotShardContext context) { finishedShardSnapshots.incrementAndGet(); } else { expectedFileSnapshotTasks.addAndGet(filesToUpload); - 
ActionListener uploadListener = new CountDownActionListener( - filesToUpload, - ActionListener.wrap(finishedShardSnapshots::incrementAndGet) - ); + ActionListener uploadListener = new CountDownActionListener(filesToUpload, finishedShardSnapshots::incrementAndGet); for (int i = 0; i < filesToUpload; i++) { taskRunner.enqueueFileSnapshot(context, ShardSnapshotTaskRunnerTests::dummyFileInfo, uploadListener); } From b3595d07db84879e48a83efd254b682cea455515 Mon Sep 17 00:00:00 2001 From: Joe Gallo Date: Mon, 12 Dec 2022 10:18:57 -0500 Subject: [PATCH 4/7] Add stricter validation We don't want to allow too many invocations of the listener --- .../elasticsearch/action/support/CountDownActionListener.java | 2 ++ 1 file changed, 2 insertions(+) diff --git a/server/src/main/java/org/elasticsearch/action/support/CountDownActionListener.java b/server/src/main/java/org/elasticsearch/action/support/CountDownActionListener.java index 31acf9f5a44c1..99b3f943a9684 100644 --- a/server/src/main/java/org/elasticsearch/action/support/CountDownActionListener.java +++ b/server/src/main/java/org/elasticsearch/action/support/CountDownActionListener.java @@ -48,6 +48,7 @@ public CountDownActionListener(int groupSize, Runnable runnable) { @Override public void onResponse(Void element) { + assert countDown.isCountedDown() == false; if (countDown.countDown()) { if (failure.get() != null) { super.onFailure(failure.get()); @@ -59,6 +60,7 @@ public void onResponse(Void element) { @Override public void onFailure(Exception e) { + assert countDown.isCountedDown() == false; if (failure.compareAndSet(null, e) == false) { failure.accumulateAndGet(e, (current, update) -> { // we have to avoid self-suppression! From 40a10eedca7397450ecd103564139b93c5cb31db Mon Sep 17 00:00:00 2001 From: Joe Gallo Date: Mon, 12 Dec 2022 11:07:42 -0500 Subject: [PATCH 5/7] Rewrite in terms of AtomicInteger So this is inspired by CountDown, but doesn't *use* CountDown. 
Crucially, this gives us the additional validation that you don't over-invoke the listener -- you get groupSize invocations, no more. --- .../support/CountDownActionListener.java | 21 ++++++++++++------- 1 file changed, 14 insertions(+), 7 deletions(-) diff --git a/server/src/main/java/org/elasticsearch/action/support/CountDownActionListener.java b/server/src/main/java/org/elasticsearch/action/support/CountDownActionListener.java index 99b3f943a9684..2459482903363 100644 --- a/server/src/main/java/org/elasticsearch/action/support/CountDownActionListener.java +++ b/server/src/main/java/org/elasticsearch/action/support/CountDownActionListener.java @@ -8,9 +8,9 @@ package org.elasticsearch.action.support; import org.elasticsearch.action.ActionListener; -import org.elasticsearch.common.util.concurrent.CountDown; import java.util.Objects; +import java.util.concurrent.atomic.AtomicInteger; import java.util.concurrent.atomic.AtomicReference; /** @@ -20,7 +20,7 @@ */ public final class CountDownActionListener extends ActionListener.Delegating { - private final CountDown countDown; + private final AtomicInteger countDown; private final AtomicReference failure = new AtomicReference<>(); /** @@ -34,7 +34,7 @@ public CountDownActionListener(int groupSize, ActionListener delegate) { assert false : "illegal group size [" + groupSize + "]"; throw new IllegalArgumentException("groupSize must be greater than 0 but was " + groupSize); } - countDown = new CountDown(groupSize); + countDown = new AtomicInteger(groupSize); } /** @@ -46,10 +46,18 @@ public CountDownActionListener(int groupSize, Runnable runnable) { this(groupSize, ActionListener.wrap(Objects.requireNonNull(runnable))); } + private boolean countDown() { + return countDown.updateAndGet(current -> { + if (current <= 0) { + throw new IllegalStateException("over-decrementing of count down, listener invoked too many times"); + } + return current - 1; + }) == 0; + } + @Override public void onResponse(Void element) { - assert 
countDown.isCountedDown() == false; - if (countDown.countDown()) { + if (countDown()) { if (failure.get() != null) { super.onFailure(failure.get()); } else { @@ -60,7 +68,6 @@ public void onResponse(Void element) { @Override public void onFailure(Exception e) { - assert countDown.isCountedDown() == false; if (failure.compareAndSet(null, e) == false) { failure.accumulateAndGet(e, (current, update) -> { // we have to avoid self-suppression! @@ -70,7 +77,7 @@ public void onFailure(Exception e) { return current; }); } - if (countDown.countDown()) { + if (countDown()) { super.onFailure(failure.get()); } } From 4cf5d952885acf701c5bb7c190fc7a92455726f9 Mon Sep 17 00:00:00 2001 From: Joe Gallo Date: Mon, 12 Dec 2022 11:08:56 -0500 Subject: [PATCH 6/7] Add validation tests --- .../support/CountDownActionListenerTests.java | 58 +++++++++++++++++++ 1 file changed, 58 insertions(+) diff --git a/server/src/test/java/org/elasticsearch/action/support/CountDownActionListenerTests.java b/server/src/test/java/org/elasticsearch/action/support/CountDownActionListenerTests.java index e1a8d6c1447b8..10f0cea63e24b 100644 --- a/server/src/test/java/org/elasticsearch/action/support/CountDownActionListenerTests.java +++ b/server/src/test/java/org/elasticsearch/action/support/CountDownActionListenerTests.java @@ -96,6 +96,64 @@ public void onFailure(Exception e) { assertFalse(called.get()); } + public void testValidation() throws InterruptedException { + AtomicBoolean called = new AtomicBoolean(false); + ActionListener result = new ActionListener<>() { + @Override + public void onResponse(Void ignored) { + called.compareAndSet(false, true); + } + + @Override + public void onFailure(Exception e) { + called.compareAndSet(false, true); + } + }; + + // can't use a groupSize of 0 + expectThrows(AssertionError.class, () -> new CountDownActionListener(0, result)); + + // can't use a null listener or runnable + expectThrows(NullPointerException.class, () -> new CountDownActionListener(1, 
(ActionListener) null)); + expectThrows(NullPointerException.class, () -> new CountDownActionListener(1, (Runnable) null)); + + final int overage = randomIntBetween(1, 10); + AtomicInteger exceptionsThrown = new AtomicInteger(); + final int groupSize = randomIntBetween(10, 1000); + AtomicInteger count = new AtomicInteger(); + CountDownActionListener listener = new CountDownActionListener(groupSize, result); + int numThreads = randomIntBetween(2, 5); + Thread[] threads = new Thread[numThreads]; + CyclicBarrier barrier = new CyclicBarrier(numThreads); + for (int i = 0; i < numThreads; i++) { + threads[i] = new Thread(() -> { + try { + barrier.await(10, TimeUnit.SECONDS); + } catch (Exception e) { + throw new AssertionError(e); + } + int c; + while ((c = count.incrementAndGet()) <= groupSize + overage) { + try { + if (c % 10 == 1) { // a mix of failures and non-failures + listener.onFailure(new RuntimeException()); + } else { + listener.onResponse(null); + } + } catch (IllegalStateException e) { + exceptionsThrown.incrementAndGet(); + } + } + }); + threads[i].start(); + } + for (Thread t : threads) { + t.join(); + } + assertTrue(called.get()); + assertEquals(overage, exceptionsThrown.get()); + } + public void testConcurrentFailures() throws InterruptedException { AtomicReference finalException = new AtomicReference<>(); int numGroups = randomIntBetween(10, 100); From a29aae960bb657c6221cbf19e9a5f8ab136a4923 Mon Sep 17 00:00:00 2001 From: Joe Gallo Date: Tue, 13 Dec 2022 09:00:58 -0500 Subject: [PATCH 7/7] Prefer assertions over exceptions for over-invoking --- .../action/support/CountDownActionListener.java | 9 +++------ .../action/support/CountDownActionListenerTests.java | 8 ++++---- 2 files changed, 7 insertions(+), 10 deletions(-) diff --git a/server/src/main/java/org/elasticsearch/action/support/CountDownActionListener.java b/server/src/main/java/org/elasticsearch/action/support/CountDownActionListener.java index 2459482903363..e9da843d34c25 100644 --- 
a/server/src/main/java/org/elasticsearch/action/support/CountDownActionListener.java +++ b/server/src/main/java/org/elasticsearch/action/support/CountDownActionListener.java @@ -47,12 +47,9 @@ public CountDownActionListener(int groupSize, Runnable runnable) { } private boolean countDown() { - return countDown.updateAndGet(current -> { - if (current <= 0) { - throw new IllegalStateException("over-decrementing of count down, listener invoked too many times"); - } - return current - 1; - }) == 0; + final var result = countDown.getAndUpdate(current -> Math.max(0, current - 1)); + assert result > 0; + return result == 1; } @Override diff --git a/server/src/test/java/org/elasticsearch/action/support/CountDownActionListenerTests.java b/server/src/test/java/org/elasticsearch/action/support/CountDownActionListenerTests.java index 10f0cea63e24b..7655c2fd172f4 100644 --- a/server/src/test/java/org/elasticsearch/action/support/CountDownActionListenerTests.java +++ b/server/src/test/java/org/elasticsearch/action/support/CountDownActionListenerTests.java @@ -118,7 +118,7 @@ public void onFailure(Exception e) { expectThrows(NullPointerException.class, () -> new CountDownActionListener(1, (Runnable) null)); final int overage = randomIntBetween(1, 10); - AtomicInteger exceptionsThrown = new AtomicInteger(); + AtomicInteger assertionsTriggered = new AtomicInteger(); final int groupSize = randomIntBetween(10, 1000); AtomicInteger count = new AtomicInteger(); CountDownActionListener listener = new CountDownActionListener(groupSize, result); @@ -140,8 +140,8 @@ public void onFailure(Exception e) { } else { listener.onResponse(null); } - } catch (IllegalStateException e) { - exceptionsThrown.incrementAndGet(); + } catch (AssertionError e) { + assertionsTriggered.incrementAndGet(); } } }); @@ -151,7 +151,7 @@ public void onFailure(Exception e) { t.join(); } assertTrue(called.get()); - assertEquals(overage, exceptionsThrown.get()); + assertEquals(overage, assertionsTriggered.get()); } 
public void testConcurrentFailures() throws InterruptedException {