Skip to content

Commit 4aa4a0d

Browse files
authored
Improve scalability of BroadcastReplicationActions (#92902)
BroadcastReplicationAction derivatives (`POST /<indices>/_refresh` and `POST /<indices>/_flush`) are pretty inefficient when targeting high shard counts due to how `TransportBroadcastReplicationAction` works:

- It computes the list of all target shards up-front on the calling (transport) thread.
- It accumulates responses in a `CopyOnWriteArrayList`, which takes quadratic work to populate, even though nothing reads this list until it's fully populated.
- It then mostly discards the accumulated responses, keeping only the total number of shards, the number of successful shards, and a list of any failures.
- Each failure is wrapped up in a `ReplicationResponse.ShardInfo.Failure` but then unwrapped at the end to be re-wrapped in a `DefaultShardOperationFailedException`.

This commit fixes all of this:

- The computation of the list of shards, and the sending of the per-shard requests, now happens on the relevant threadpool (`REFRESH` or `FLUSH`) rather than a transport thread.
- The failures are tracked in a regular `ArrayList`, avoiding the accidentally-quadratic complexity.
- Rather than accumulating the full responses for later processing, we track the counts and failures directly.
- The failures are tracked in their final form, skipping the unwrap-and-rewrap step at the end.

Relates #77466
Relates #92729
1 parent 1a9150d commit 4aa4a0d

File tree

4 files changed

+104
-86
lines changed

4 files changed

+104
-86
lines changed

server/src/main/java/org/elasticsearch/action/admin/indices/flush/TransportFlushAction.java

Lines changed: 3 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -17,6 +17,7 @@
1717
import org.elasticsearch.cluster.service.ClusterService;
1818
import org.elasticsearch.common.inject.Inject;
1919
import org.elasticsearch.index.shard.ShardId;
20+
import org.elasticsearch.threadpool.ThreadPool;
2021
import org.elasticsearch.transport.TransportService;
2122

2223
import java.util.List;
@@ -46,15 +47,11 @@ public TransportFlushAction(
4647
client,
4748
actionFilters,
4849
indexNameExpressionResolver,
49-
TransportShardFlushAction.TYPE
50+
TransportShardFlushAction.TYPE,
51+
ThreadPool.Names.FLUSH
5052
);
5153
}
5254

53-
@Override
54-
protected ReplicationResponse newShardResponse() {
55-
return new ReplicationResponse();
56-
}
57-
5855
@Override
5956
protected ShardFlushRequest newShardRequest(FlushRequest request, ShardId shardId) {
6057
return new ShardFlushRequest(request, shardId);

server/src/main/java/org/elasticsearch/action/admin/indices/refresh/TransportRefreshAction.java

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
import org.elasticsearch.cluster.service.ClusterService;
2020
import org.elasticsearch.common.inject.Inject;
2121
import org.elasticsearch.index.shard.ShardId;
22+
import org.elasticsearch.threadpool.ThreadPool;
2223
import org.elasticsearch.transport.TransportService;
2324

2425
import java.util.List;
@@ -48,15 +49,11 @@ public TransportRefreshAction(
4849
client,
4950
actionFilters,
5051
indexNameExpressionResolver,
51-
TransportShardRefreshAction.TYPE
52+
TransportShardRefreshAction.TYPE,
53+
ThreadPool.Names.REFRESH
5254
);
5355
}
5456

55-
@Override
56-
protected ReplicationResponse newShardResponse() {
57-
return new ReplicationResponse();
58-
}
59-
6057
@Override
6158
protected BasicReplicationRequest newShardRequest(RefreshRequest request, ShardId shardId) {
6259
BasicReplicationRequest replicationRequest = new BasicReplicationRequest(shardId);

server/src/main/java/org/elasticsearch/action/support/replication/TransportBroadcastReplicationAction.java

Lines changed: 96 additions & 68 deletions
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,8 @@
88

99
package org.elasticsearch.action.support.replication;
1010

11-
import org.elasticsearch.ExceptionsHelper;
1211
import org.elasticsearch.action.ActionListener;
12+
import org.elasticsearch.action.ActionRunnable;
1313
import org.elasticsearch.action.ActionType;
1414
import org.elasticsearch.action.support.ActionFilters;
1515
import org.elasticsearch.action.support.DefaultShardOperationFailedException;
@@ -26,14 +26,16 @@
2626
import org.elasticsearch.cluster.routing.IndexRoutingTable;
2727
import org.elasticsearch.cluster.service.ClusterService;
2828
import org.elasticsearch.common.io.stream.Writeable;
29+
import org.elasticsearch.core.CheckedConsumer;
2930
import org.elasticsearch.index.shard.ShardId;
3031
import org.elasticsearch.tasks.Task;
3132
import org.elasticsearch.transport.TransportService;
33+
import org.elasticsearch.transport.Transports;
3234

3335
import java.util.ArrayList;
3436
import java.util.Arrays;
3537
import java.util.List;
36-
import java.util.concurrent.CopyOnWriteArrayList;
38+
import java.util.Map;
3739

3840
/**
3941
* Base class for requests that should be executed on all shards of an index or several indices.
@@ -49,6 +51,7 @@ public abstract class TransportBroadcastReplicationAction<
4951
private final ClusterService clusterService;
5052
private final IndexNameExpressionResolver indexNameExpressionResolver;
5153
private final NodeClient client;
54+
private final String executor;
5255

5356
public TransportBroadcastReplicationAction(
5457
String name,
@@ -58,58 +61,112 @@ public TransportBroadcastReplicationAction(
5861
NodeClient client,
5962
ActionFilters actionFilters,
6063
IndexNameExpressionResolver indexNameExpressionResolver,
61-
ActionType<ShardResponse> replicatedBroadcastShardAction
64+
ActionType<ShardResponse> replicatedBroadcastShardAction,
65+
String executor
6266
) {
6367
super(name, transportService, actionFilters, requestReader);
6468
this.client = client;
6569
this.replicatedBroadcastShardAction = replicatedBroadcastShardAction;
6670
this.clusterService = clusterService;
6771
this.indexNameExpressionResolver = indexNameExpressionResolver;
72+
this.executor = executor;
6873
}
6974

7075
@Override
7176
protected void doExecute(Task task, Request request, ActionListener<Response> listener) {
72-
final ClusterState clusterState = clusterService.state();
73-
List<ShardId> shards = shards(request, clusterState);
74-
final CopyOnWriteArrayList<ShardResponse> shardsResponses = new CopyOnWriteArrayList<>();
75-
try (var refs = new RefCountingRunnable(() -> finishAndNotifyListener(listener, shardsResponses))) {
76-
for (final ShardId shardId : shards) {
77-
ActionListener<ShardResponse> shardActionListener = new ActionListener<ShardResponse>() {
78-
@Override
79-
public void onResponse(ShardResponse shardResponse) {
80-
shardsResponses.add(shardResponse);
81-
logger.trace("{}: got response from {}", actionName, shardId);
77+
clusterService.threadPool().executor(executor).execute(ActionRunnable.wrap(listener, createAsyncAction(task, request)));
78+
}
79+
80+
private CheckedConsumer<ActionListener<Response>, Exception> createAsyncAction(Task task, Request request) {
81+
return new CheckedConsumer<ActionListener<Response>, Exception>() {
82+
83+
private int totalShardCopyCount;
84+
private int successShardCopyCount;
85+
private final List<DefaultShardOperationFailedException> allFailures = new ArrayList<>();
86+
87+
@Override
88+
public void accept(ActionListener<Response> listener) {
89+
assert totalShardCopyCount == 0 && successShardCopyCount == 0 && allFailures.isEmpty() : "shouldn't call this twice";
90+
91+
final ClusterState clusterState = clusterService.state();
92+
final List<ShardId> shards = shards(request, clusterState);
93+
final Map<String, IndexMetadata> indexMetadataByName = clusterState.getMetadata().indices();
94+
95+
try (var refs = new RefCountingRunnable(() -> finish(listener))) {
96+
for (final ShardId shardId : shards) {
97+
// NB This sends O(#shards) requests in a tight loop; TODO add some throttling here?
98+
shardExecute(
99+
task,
100+
request,
101+
shardId,
102+
ActionListener.releaseAfter(new ReplicationResponseActionListener(shardId, indexMetadataByName), refs.acquire())
103+
);
82104
}
105+
}
106+
}
107+
108+
private synchronized void addShardResponse(int numCopies, int successful, List<DefaultShardOperationFailedException> failures) {
109+
totalShardCopyCount += numCopies;
110+
successShardCopyCount += successful;
111+
allFailures.addAll(failures);
112+
}
113+
114+
void finish(ActionListener<Response> listener) {
115+
// no need for synchronized here, the RefCountingRunnable guarantees that all the addShardResponse calls happen-before here
116+
logger.trace("{}: got all shard responses", actionName);
117+
listener.onResponse(newResponse(successShardCopyCount, allFailures.size(), totalShardCopyCount, allFailures));
118+
}
119+
120+
class ReplicationResponseActionListener implements ActionListener<ShardResponse> {
121+
private final ShardId shardId;
122+
private final Map<String, IndexMetadata> indexMetadataByName;
123+
124+
ReplicationResponseActionListener(ShardId shardId, Map<String, IndexMetadata> indexMetadataByName) {
125+
this.shardId = shardId;
126+
this.indexMetadataByName = indexMetadataByName;
127+
}
83128

84-
@Override
85-
public void onFailure(Exception e) {
86-
logger.trace("{}: got failure from {}", actionName, shardId);
87-
int totalNumCopies = clusterState.getMetadata().getIndexSafe(shardId.getIndex()).getNumberOfReplicas() + 1;
88-
ShardResponse shardResponse = newShardResponse();
89-
ReplicationResponse.ShardInfo.Failure[] failures;
90-
if (TransportActions.isShardNotAvailableException(e)) {
91-
failures = new ReplicationResponse.ShardInfo.Failure[0];
92-
} else {
93-
ReplicationResponse.ShardInfo.Failure failure = new ReplicationResponse.ShardInfo.Failure(
94-
shardId,
95-
null,
96-
e,
97-
ExceptionsHelper.status(e),
98-
true
99-
);
100-
failures = new ReplicationResponse.ShardInfo.Failure[totalNumCopies];
101-
Arrays.fill(failures, failure);
102-
}
103-
shardResponse.setShardInfo(new ReplicationResponse.ShardInfo(totalNumCopies, 0, failures));
104-
shardsResponses.add(shardResponse);
129+
@Override
130+
public void onResponse(ShardResponse shardResponse) {
131+
assert shardResponse != null;
132+
logger.trace("{}: got response from {}", actionName, shardId);
133+
addShardResponse(
134+
shardResponse.getShardInfo().getTotal(),
135+
shardResponse.getShardInfo().getSuccessful(),
136+
Arrays.stream(shardResponse.getShardInfo().getFailures())
137+
.map(
138+
f -> new DefaultShardOperationFailedException(
139+
new BroadcastShardOperationFailedException(shardId, f.getCause())
140+
)
141+
)
142+
.toList()
143+
);
144+
}
145+
146+
@Override
147+
public void onFailure(Exception e) {
148+
logger.trace("{}: got failure from {}", actionName, shardId);
149+
final int numCopies = indexMetadataByName.get(shardId.getIndexName()).getNumberOfReplicas() + 1;
150+
final List<DefaultShardOperationFailedException> result;
151+
if (TransportActions.isShardNotAvailableException(e)) {
152+
result = List.of();
153+
} else {
154+
final var failures = new DefaultShardOperationFailedException[numCopies];
155+
Arrays.fill(
156+
failures,
157+
new DefaultShardOperationFailedException(new BroadcastShardOperationFailedException(shardId, e))
158+
);
159+
result = Arrays.asList(failures);
105160
}
106-
};
107-
shardExecute(task, request, shardId, ActionListener.releaseAfter(shardActionListener, refs.acquire()));
161+
addShardResponse(numCopies, 0, result);
162+
}
108163
}
109-
}
164+
165+
};
110166
}
111167

112168
protected void shardExecute(Task task, Request request, ShardId shardId, ActionListener<ShardResponse> shardActionListener) {
169+
assert Transports.assertNotTransportThread("may hit all the shards");
113170
ShardRequest shardRequest = newShardRequest(request, shardId);
114171
shardRequest.setParentTask(clusterService.localNode().getId(), task.getId());
115172
client.executeLocally(replicatedBroadcastShardAction, shardRequest, shardActionListener);
@@ -119,6 +176,7 @@ protected void shardExecute(Task task, Request request, ShardId shardId, ActionL
119176
* @return all shard ids the request should run on
120177
*/
121178
protected List<ShardId> shards(Request request, ClusterState clusterState) {
179+
assert Transports.assertNotTransportThread("may hit all the shards");
122180
List<ShardId> shardIds = new ArrayList<>();
123181
String[] concreteIndices = indexNameExpressionResolver.concreteIndexNames(clusterState, request);
124182
for (String index : concreteIndices) {
@@ -133,43 +191,13 @@ protected List<ShardId> shards(Request request, ClusterState clusterState) {
133191
return shardIds;
134192
}
135193

136-
protected abstract ShardResponse newShardResponse();
137-
138194
protected abstract ShardRequest newShardRequest(Request request, ShardId shardId);
139195

140-
private void finishAndNotifyListener(ActionListener<Response> listener, CopyOnWriteArrayList<ShardResponse> shardsResponses) {
141-
logger.trace("{}: got all shard responses", actionName);
142-
int successfulShards = 0;
143-
int failedShards = 0;
144-
int totalNumCopies = 0;
145-
List<DefaultShardOperationFailedException> shardFailures = null;
146-
for (int i = 0; i < shardsResponses.size(); i++) {
147-
ReplicationResponse shardResponse = shardsResponses.get(i);
148-
if (shardResponse == null) {
149-
// non active shard, ignore
150-
} else {
151-
failedShards += shardResponse.getShardInfo().getFailed();
152-
successfulShards += shardResponse.getShardInfo().getSuccessful();
153-
totalNumCopies += shardResponse.getShardInfo().getTotal();
154-
if (shardFailures == null) {
155-
shardFailures = new ArrayList<>();
156-
}
157-
for (ReplicationResponse.ShardInfo.Failure failure : shardResponse.getShardInfo().getFailures()) {
158-
shardFailures.add(
159-
new DefaultShardOperationFailedException(
160-
new BroadcastShardOperationFailedException(failure.fullShardId(), failure.getCause())
161-
)
162-
);
163-
}
164-
}
165-
}
166-
listener.onResponse(newResponse(successfulShards, failedShards, totalNumCopies, shardFailures));
167-
}
168-
169196
protected abstract Response newResponse(
170197
int successfulShards,
171198
int failedShards,
172199
int totalNumCopies,
173200
List<DefaultShardOperationFailedException> shardFailures
174201
);
202+
175203
}

server/src/test/java/org/elasticsearch/action/support/replication/BroadcastReplicationTests.java

Lines changed: 2 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -254,15 +254,11 @@ private class TestBroadcastReplicationAction extends TransportBroadcastReplicati
254254
null,
255255
actionFilters,
256256
indexNameExpressionResolver,
257-
null
257+
null,
258+
ThreadPool.Names.SAME
258259
);
259260
}
260261

261-
@Override
262-
protected ReplicationResponse newShardResponse() {
263-
return new ReplicationResponse();
264-
}
265-
266262
@Override
267263
protected BasicReplicationRequest newShardRequest(DummyBroadcastRequest request, ShardId shardId) {
268264
return new BasicReplicationRequest(shardId);

0 commit comments

Comments
 (0)