Skip to content

Improve scalability of BroadcastReplicationActions #92902

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import org.elasticsearch.cluster.service.ClusterService;
import org.elasticsearch.common.inject.Inject;
import org.elasticsearch.index.shard.ShardId;
import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.transport.TransportService;

import java.util.List;
Expand Down Expand Up @@ -46,15 +47,11 @@ public TransportFlushAction(
client,
actionFilters,
indexNameExpressionResolver,
TransportShardFlushAction.TYPE
TransportShardFlushAction.TYPE,
ThreadPool.Names.FLUSH
);
}

@Override
protected ReplicationResponse newShardResponse() {
return new ReplicationResponse();
}

@Override
protected ShardFlushRequest newShardRequest(FlushRequest request, ShardId shardId) {
return new ShardFlushRequest(request, shardId);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import org.elasticsearch.cluster.service.ClusterService;
import org.elasticsearch.common.inject.Inject;
import org.elasticsearch.index.shard.ShardId;
import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.transport.TransportService;

import java.util.List;
Expand Down Expand Up @@ -48,15 +49,11 @@ public TransportRefreshAction(
client,
actionFilters,
indexNameExpressionResolver,
TransportShardRefreshAction.TYPE
TransportShardRefreshAction.TYPE,
ThreadPool.Names.REFRESH
);
}

@Override
protected ReplicationResponse newShardResponse() {
return new ReplicationResponse();
}

@Override
protected BasicReplicationRequest newShardRequest(RefreshRequest request, ShardId shardId) {
BasicReplicationRequest replicationRequest = new BasicReplicationRequest(shardId);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,8 @@

package org.elasticsearch.action.support.replication;

import org.elasticsearch.ExceptionsHelper;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.action.ActionRunnable;
import org.elasticsearch.action.ActionType;
import org.elasticsearch.action.support.ActionFilters;
import org.elasticsearch.action.support.DefaultShardOperationFailedException;
Expand All @@ -26,14 +26,16 @@
import org.elasticsearch.cluster.routing.IndexRoutingTable;
import org.elasticsearch.cluster.service.ClusterService;
import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.core.CheckedConsumer;
import org.elasticsearch.index.shard.ShardId;
import org.elasticsearch.tasks.Task;
import org.elasticsearch.transport.TransportService;
import org.elasticsearch.transport.Transports;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.concurrent.CopyOnWriteArrayList;
import java.util.Map;

/**
* Base class for requests that should be executed on all shards of an index or several indices.
Expand All @@ -49,6 +51,7 @@ public abstract class TransportBroadcastReplicationAction<
private final ClusterService clusterService;
private final IndexNameExpressionResolver indexNameExpressionResolver;
private final NodeClient client;
private final String executor;

public TransportBroadcastReplicationAction(
String name,
Expand All @@ -58,58 +61,111 @@ public TransportBroadcastReplicationAction(
NodeClient client,
ActionFilters actionFilters,
IndexNameExpressionResolver indexNameExpressionResolver,
ActionType<ShardResponse> replicatedBroadcastShardAction
ActionType<ShardResponse> replicatedBroadcastShardAction,
String executor
) {
super(name, transportService, actionFilters, requestReader);
this.client = client;
this.replicatedBroadcastShardAction = replicatedBroadcastShardAction;
this.clusterService = clusterService;
this.indexNameExpressionResolver = indexNameExpressionResolver;
this.executor = executor;
}

@Override
protected void doExecute(Task task, Request request, ActionListener<Response> listener) {
final ClusterState clusterState = clusterService.state();
List<ShardId> shards = shards(request, clusterState);
final CopyOnWriteArrayList<ShardResponse> shardsResponses = new CopyOnWriteArrayList<>();
try (var refs = new RefCountingRunnable(() -> finishAndNotifyListener(listener, shardsResponses))) {
for (final ShardId shardId : shards) {
ActionListener<ShardResponse> shardActionListener = new ActionListener<ShardResponse>() {
@Override
public void onResponse(ShardResponse shardResponse) {
shardsResponses.add(shardResponse);
logger.trace("{}: got response from {}", actionName, shardId);
clusterService.threadPool().executor(executor).execute(ActionRunnable.wrap(listener, createAsyncAction(task, request)));
}

private CheckedConsumer<ActionListener<Response>, Exception> createAsyncAction(Task task, Request request) {
return new CheckedConsumer<ActionListener<Response>, Exception>() {

private int totalShardCopyCount;
private int successShardCopyCount;
private final List<DefaultShardOperationFailedException> allFailures = new ArrayList<>();

@Override
public void accept(ActionListener<Response> listener) {
assert totalShardCopyCount == 0 && successShardCopyCount == 0 && allFailures.isEmpty() : "shouldn't call this twice";

final ClusterState clusterState = clusterService.state();
final List<ShardId> shards = shards(request, clusterState);
final Map<String, IndexMetadata> indexMetadataByName = clusterState.getMetadata().indices();

try (var refs = new RefCountingRunnable(() -> finish(listener))) {
for (final ShardId shardId : shards) {
// NB This sends O(#shards) requests in a tight loop; TODO add some throttling here?
shardExecute(
task,
request,
shardId,
ActionListener.releaseAfter(new ReplicationResponseActionListener(shardId, indexMetadataByName), refs.acquire())
);
}
}
}

@Override
public void onFailure(Exception e) {
logger.trace("{}: got failure from {}", actionName, shardId);
int totalNumCopies = clusterState.getMetadata().getIndexSafe(shardId.getIndex()).getNumberOfReplicas() + 1;
ShardResponse shardResponse = newShardResponse();
ReplicationResponse.ShardInfo.Failure[] failures;
if (TransportActions.isShardNotAvailableException(e)) {
failures = new ReplicationResponse.ShardInfo.Failure[0];
} else {
ReplicationResponse.ShardInfo.Failure failure = new ReplicationResponse.ShardInfo.Failure(
shardId,
null,
e,
ExceptionsHelper.status(e),
true
);
failures = new ReplicationResponse.ShardInfo.Failure[totalNumCopies];
Arrays.fill(failures, failure);
}
shardResponse.setShardInfo(new ReplicationResponse.ShardInfo(totalNumCopies, 0, failures));
shardsResponses.add(shardResponse);
private synchronized void addShardResponse(int numCopies, int successful, List<DefaultShardOperationFailedException> failures) {
totalShardCopyCount += numCopies;
successShardCopyCount += successful;
allFailures.addAll(failures);
}

void finish(ActionListener<Response> listener) {
// no need for synchronized here, the RefCountingRunnable guarantees that all the addShardResponse calls happen-before here
logger.trace("{}: got all shard responses", actionName);
listener.onResponse(newResponse(successShardCopyCount, allFailures.size(), totalShardCopyCount, allFailures));
}

class ReplicationResponseActionListener implements ActionListener<ShardResponse> {
private final ShardId shardId;
private final Map<String, IndexMetadata> indexMetadataByName;

ReplicationResponseActionListener(ShardId shardId, Map<String, IndexMetadata> indexMetadataByName) {
this.shardId = shardId;
this.indexMetadataByName = indexMetadataByName;
}

@Override
public void onResponse(ShardResponse shardResponse) {
assert shardResponse != null;
logger.trace("{}: got response from {}", actionName, shardId);
addShardResponse(
shardResponse.getShardInfo().getTotal(),
shardResponse.getShardInfo().getSuccessful(),
Arrays.stream(shardResponse.getShardInfo().getFailures())
.map(
f -> new DefaultShardOperationFailedException(
new BroadcastShardOperationFailedException(shardId, f.getCause())
)
)
.toList()
);
}

@Override
public void onFailure(Exception e) {
logger.trace("{}: got failure from {}", actionName, shardId);
final int numCopies = indexMetadataByName.get(shardId.getIndexName()).getNumberOfReplicas() + 1;
addShardResponse(numCopies, 0, createSyntheticFailures(numCopies, e));
}

private List<DefaultShardOperationFailedException> createSyntheticFailures(int numCopies, Exception e) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not sure if it really deserves a dedicated method, we can probably include this in onFailure()

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, this was better at some point in the process but no longer needed indeed.

if (TransportActions.isShardNotAvailableException(e)) {
return List.of();
}
};
shardExecute(task, request, shardId, ActionListener.releaseAfter(shardActionListener, refs.acquire()));

final var failures = new DefaultShardOperationFailedException[numCopies];
Arrays.fill(failures, new DefaultShardOperationFailedException(new BroadcastShardOperationFailedException(shardId, e)));
return Arrays.asList(failures);
}
}
}

};
}

protected void shardExecute(Task task, Request request, ShardId shardId, ActionListener<ShardResponse> shardActionListener) {
assert Transports.assertNotTransportThread("may hit all the shards");
ShardRequest shardRequest = newShardRequest(request, shardId);
shardRequest.setParentTask(clusterService.localNode().getId(), task.getId());
client.executeLocally(replicatedBroadcastShardAction, shardRequest, shardActionListener);
Expand All @@ -119,6 +175,7 @@ protected void shardExecute(Task task, Request request, ShardId shardId, ActionL
* @return all shard ids the request should run on
*/
protected List<ShardId> shards(Request request, ClusterState clusterState) {
assert Transports.assertNotTransportThread("may hit all the shards");
List<ShardId> shardIds = new ArrayList<>();
String[] concreteIndices = indexNameExpressionResolver.concreteIndexNames(clusterState, request);
for (String index : concreteIndices) {
Expand All @@ -133,43 +190,13 @@ protected List<ShardId> shards(Request request, ClusterState clusterState) {
return shardIds;
}

protected abstract ShardResponse newShardResponse();

protected abstract ShardRequest newShardRequest(Request request, ShardId shardId);

private void finishAndNotifyListener(ActionListener<Response> listener, CopyOnWriteArrayList<ShardResponse> shardsResponses) {
logger.trace("{}: got all shard responses", actionName);
int successfulShards = 0;
int failedShards = 0;
int totalNumCopies = 0;
List<DefaultShardOperationFailedException> shardFailures = null;
for (int i = 0; i < shardsResponses.size(); i++) {
ReplicationResponse shardResponse = shardsResponses.get(i);
if (shardResponse == null) {
// non active shard, ignore
} else {
failedShards += shardResponse.getShardInfo().getFailed();
successfulShards += shardResponse.getShardInfo().getSuccessful();
totalNumCopies += shardResponse.getShardInfo().getTotal();
if (shardFailures == null) {
shardFailures = new ArrayList<>();
}
for (ReplicationResponse.ShardInfo.Failure failure : shardResponse.getShardInfo().getFailures()) {
shardFailures.add(
new DefaultShardOperationFailedException(
new BroadcastShardOperationFailedException(failure.fullShardId(), failure.getCause())
)
);
}
}
}
listener.onResponse(newResponse(successfulShards, failedShards, totalNumCopies, shardFailures));
}

protected abstract Response newResponse(
int successfulShards,
int failedShards,
int totalNumCopies,
List<DefaultShardOperationFailedException> shardFailures
);

}
Original file line number Diff line number Diff line change
Expand Up @@ -254,15 +254,11 @@ private class TestBroadcastReplicationAction extends TransportBroadcastReplicati
null,
actionFilters,
indexNameExpressionResolver,
null
null,
ThreadPool.Names.SAME
);
}

@Override
protected ReplicationResponse newShardResponse() {
return new ReplicationResponse();
}

@Override
protected BasicReplicationRequest newShardRequest(DummyBroadcastRequest request, ShardId shardId) {
return new BasicReplicationRequest(shardId);
Expand Down