diff --git a/server/src/main/java/org/elasticsearch/action/admin/indices/flush/TransportFlushAction.java b/server/src/main/java/org/elasticsearch/action/admin/indices/flush/TransportFlushAction.java index a828f6e413d77..6fce79e31b911 100644 --- a/server/src/main/java/org/elasticsearch/action/admin/indices/flush/TransportFlushAction.java +++ b/server/src/main/java/org/elasticsearch/action/admin/indices/flush/TransportFlushAction.java @@ -17,6 +17,7 @@ import org.elasticsearch.cluster.service.ClusterService; import org.elasticsearch.common.inject.Inject; import org.elasticsearch.index.shard.ShardId; +import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.transport.TransportService; import java.util.List; @@ -46,15 +47,11 @@ public TransportFlushAction( client, actionFilters, indexNameExpressionResolver, - TransportShardFlushAction.TYPE + TransportShardFlushAction.TYPE, + ThreadPool.Names.FLUSH ); } - @Override - protected ReplicationResponse newShardResponse() { - return new ReplicationResponse(); - } - @Override protected ShardFlushRequest newShardRequest(FlushRequest request, ShardId shardId) { return new ShardFlushRequest(request, shardId); diff --git a/server/src/main/java/org/elasticsearch/action/admin/indices/refresh/TransportRefreshAction.java b/server/src/main/java/org/elasticsearch/action/admin/indices/refresh/TransportRefreshAction.java index ff9f6640b4120..ceb940502da5d 100644 --- a/server/src/main/java/org/elasticsearch/action/admin/indices/refresh/TransportRefreshAction.java +++ b/server/src/main/java/org/elasticsearch/action/admin/indices/refresh/TransportRefreshAction.java @@ -19,6 +19,7 @@ import org.elasticsearch.cluster.service.ClusterService; import org.elasticsearch.common.inject.Inject; import org.elasticsearch.index.shard.ShardId; +import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.transport.TransportService; import java.util.List; @@ -48,15 +49,11 @@ public TransportRefreshAction( client, actionFilters, indexNameExpressionResolver, - TransportShardRefreshAction.TYPE + TransportShardRefreshAction.TYPE, + ThreadPool.Names.REFRESH ); } - @Override - protected ReplicationResponse newShardResponse() { - return new ReplicationResponse(); - } - @Override protected BasicReplicationRequest newShardRequest(RefreshRequest request, ShardId shardId) { BasicReplicationRequest replicationRequest = new BasicReplicationRequest(shardId); diff --git a/server/src/main/java/org/elasticsearch/action/support/replication/TransportBroadcastReplicationAction.java b/server/src/main/java/org/elasticsearch/action/support/replication/TransportBroadcastReplicationAction.java index 358a0b4d0233a..94c90a0efcd83 100644 --- a/server/src/main/java/org/elasticsearch/action/support/replication/TransportBroadcastReplicationAction.java +++ b/server/src/main/java/org/elasticsearch/action/support/replication/TransportBroadcastReplicationAction.java @@ -8,8 +8,8 @@ package org.elasticsearch.action.support.replication; -import org.elasticsearch.ExceptionsHelper; import org.elasticsearch.action.ActionListener; +import org.elasticsearch.action.ActionRunnable; import org.elasticsearch.action.ActionType; import org.elasticsearch.action.support.ActionFilters; import org.elasticsearch.action.support.DefaultShardOperationFailedException; @@ -26,14 +26,16 @@ import org.elasticsearch.cluster.routing.IndexRoutingTable; import org.elasticsearch.cluster.service.ClusterService; import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.core.CheckedConsumer; import org.elasticsearch.index.shard.ShardId; import org.elasticsearch.tasks.Task; import org.elasticsearch.transport.TransportService; +import org.elasticsearch.transport.Transports; import java.util.ArrayList; import java.util.Arrays; import java.util.List; -import java.util.concurrent.CopyOnWriteArrayList; +import java.util.Map; /** * Base class for requests that should be executed on all shards of an index or several indices. @@ -49,6 +51,7 @@ public abstract class TransportBroadcastReplicationAction< private final ClusterService clusterService; private final IndexNameExpressionResolver indexNameExpressionResolver; private final NodeClient client; + private final String executor; public TransportBroadcastReplicationAction( String name, @@ -58,58 +61,112 @@ public TransportBroadcastReplicationAction( NodeClient client, ActionFilters actionFilters, IndexNameExpressionResolver indexNameExpressionResolver, - ActionType replicatedBroadcastShardAction + ActionType replicatedBroadcastShardAction, + String executor ) { super(name, transportService, actionFilters, requestReader); this.client = client; this.replicatedBroadcastShardAction = replicatedBroadcastShardAction; this.clusterService = clusterService; this.indexNameExpressionResolver = indexNameExpressionResolver; + this.executor = executor; } @Override protected void doExecute(Task task, Request request, ActionListener listener) { - final ClusterState clusterState = clusterService.state(); - List shards = shards(request, clusterState); - final CopyOnWriteArrayList shardsResponses = new CopyOnWriteArrayList<>(); - try (var refs = new RefCountingRunnable(() -> finishAndNotifyListener(listener, shardsResponses))) { - for (final ShardId shardId : shards) { - ActionListener shardActionListener = new ActionListener() { - @Override - public void onResponse(ShardResponse shardResponse) { - shardsResponses.add(shardResponse); - logger.trace("{}: got response from {}", actionName, shardId); + clusterService.threadPool().executor(executor).execute(ActionRunnable.wrap(listener, createAsyncAction(task, request))); + } + + private CheckedConsumer, Exception> createAsyncAction(Task task, Request request) { + return new CheckedConsumer, Exception>() { + + private int totalShardCopyCount; + private int successShardCopyCount; + private final List allFailures = new ArrayList<>(); + + @Override + public void accept(ActionListener listener) { + assert totalShardCopyCount == 0 && successShardCopyCount == 0 && allFailures.isEmpty() : "shouldn't call this twice"; + + final ClusterState clusterState = clusterService.state(); + final List shards = shards(request, clusterState); + final Map indexMetadataByName = clusterState.getMetadata().indices(); + + try (var refs = new RefCountingRunnable(() -> finish(listener))) { + for (final ShardId shardId : shards) { + // NB This sends O(#shards) requests in a tight loop; TODO add some throttling here? + shardExecute( + task, + request, + shardId, + ActionListener.releaseAfter(new ReplicationResponseActionListener(shardId, indexMetadataByName), refs.acquire()) + ); } + } + } + + private synchronized void addShardResponse(int numCopies, int successful, List failures) { + totalShardCopyCount += numCopies; + successShardCopyCount += successful; + allFailures.addAll(failures); + } + + void finish(ActionListener listener) { + // no need for synchronized here, the RefCountingRunnable guarantees that all the addShardResponse calls happen-before here + logger.trace("{}: got all shard responses", actionName); + listener.onResponse(newResponse(successShardCopyCount, allFailures.size(), totalShardCopyCount, allFailures)); + } + + class ReplicationResponseActionListener implements ActionListener { + private final ShardId shardId; + private final Map indexMetadataByName; + + ReplicationResponseActionListener(ShardId shardId, Map indexMetadataByName) { + this.shardId = shardId; + this.indexMetadataByName = indexMetadataByName; + } - @Override - public void onFailure(Exception e) { - logger.trace("{}: got failure from {}", actionName, shardId); - int totalNumCopies = clusterState.getMetadata().getIndexSafe(shardId.getIndex()).getNumberOfReplicas() + 1; - ShardResponse shardResponse = newShardResponse(); - ReplicationResponse.ShardInfo.Failure[] failures; - if (TransportActions.isShardNotAvailableException(e)) { - failures = new ReplicationResponse.ShardInfo.Failure[0]; - } else { - ReplicationResponse.ShardInfo.Failure failure = new ReplicationResponse.ShardInfo.Failure( - shardId, - null, - e, - ExceptionsHelper.status(e), - true - ); - failures = new ReplicationResponse.ShardInfo.Failure[totalNumCopies]; - Arrays.fill(failures, failure); - } - shardResponse.setShardInfo(new ReplicationResponse.ShardInfo(totalNumCopies, 0, failures)); - shardsResponses.add(shardResponse); + @Override + public void onResponse(ShardResponse shardResponse) { + assert shardResponse != null; + logger.trace("{}: got response from {}", actionName, shardId); + addShardResponse( + shardResponse.getShardInfo().getTotal(), + shardResponse.getShardInfo().getSuccessful(), + Arrays.stream(shardResponse.getShardInfo().getFailures()) + .map( + f -> new DefaultShardOperationFailedException( + new BroadcastShardOperationFailedException(shardId, f.getCause()) + ) + ) + .toList() + ); + } + + @Override + public void onFailure(Exception e) { + logger.trace("{}: got failure from {}", actionName, shardId); + final int numCopies = indexMetadataByName.get(shardId.getIndexName()).getNumberOfReplicas() + 1; + final List result; + if (TransportActions.isShardNotAvailableException(e)) { + result = List.of(); + } else { + final var failures = new DefaultShardOperationFailedException[numCopies]; + Arrays.fill( + failures, + new DefaultShardOperationFailedException(new BroadcastShardOperationFailedException(shardId, e)) + ); + result = Arrays.asList(failures); } - }; - shardExecute(task, request, shardId, ActionListener.releaseAfter(shardActionListener, refs.acquire())); + addShardResponse(numCopies, 0, result); + } } - } + + }; } protected void shardExecute(Task task, Request request, ShardId shardId, ActionListener shardActionListener) { + assert Transports.assertNotTransportThread("may hit all the shards"); ShardRequest shardRequest = newShardRequest(request, shardId); shardRequest.setParentTask(clusterService.localNode().getId(), task.getId()); client.executeLocally(replicatedBroadcastShardAction, shardRequest, shardActionListener); @@ -119,6 +176,7 @@ protected void shardExecute(Task task, Request request, ShardId shardId, ActionL * @return all shard ids the request should run on */ protected List shards(Request request, ClusterState clusterState) { + assert Transports.assertNotTransportThread("may hit all the shards"); List shardIds = new ArrayList<>(); String[] concreteIndices = indexNameExpressionResolver.concreteIndexNames(clusterState, request); for (String index : concreteIndices) { @@ -133,43 +191,13 @@ protected List shards(Request request, ClusterState clusterState) { return shardIds; } - protected abstract ShardResponse newShardResponse(); - protected abstract ShardRequest newShardRequest(Request request, ShardId shardId); - private void finishAndNotifyListener(ActionListener listener, CopyOnWriteArrayList shardsResponses) { - logger.trace("{}: got all shard responses", actionName); - int successfulShards = 0; - int failedShards = 0; - int totalNumCopies = 0; - List shardFailures = null; - for (int i = 0; i < shardsResponses.size(); i++) { - ReplicationResponse shardResponse = shardsResponses.get(i); - if (shardResponse == null) { - // non active shard, ignore - } else { - failedShards += shardResponse.getShardInfo().getFailed(); - successfulShards += shardResponse.getShardInfo().getSuccessful(); - totalNumCopies += shardResponse.getShardInfo().getTotal(); - if (shardFailures == null) { - shardFailures = new ArrayList<>(); - } - for (ReplicationResponse.ShardInfo.Failure failure : shardResponse.getShardInfo().getFailures()) { - shardFailures.add( - new DefaultShardOperationFailedException( - new BroadcastShardOperationFailedException(failure.fullShardId(), failure.getCause()) - ) - ); - } - } - } - listener.onResponse(newResponse(successfulShards, failedShards, totalNumCopies, shardFailures)); - } - protected abstract Response newResponse( int successfulShards, int failedShards, int totalNumCopies, List shardFailures ); + } diff --git a/server/src/test/java/org/elasticsearch/action/support/replication/BroadcastReplicationTests.java b/server/src/test/java/org/elasticsearch/action/support/replication/BroadcastReplicationTests.java index 51363e76d8adb..e2672bcfb67ee 100644 --- a/server/src/test/java/org/elasticsearch/action/support/replication/BroadcastReplicationTests.java +++ b/server/src/test/java/org/elasticsearch/action/support/replication/BroadcastReplicationTests.java @@ -254,15 +254,11 @@ private class TestBroadcastReplicationAction extends TransportBroadcastReplicati null, actionFilters, indexNameExpressionResolver, - null + null, + ThreadPool.Names.SAME ); } - @Override - protected ReplicationResponse newShardResponse() { - return new ReplicationResponse(); - } - @Override protected BasicReplicationRequest newShardRequest(DummyBroadcastRequest request, ShardId shardId) { return new BasicReplicationRequest(shardId);