
Commit b66ce89

Avoid sending duplicate remote failed shard requests (#31313)
Today if a write replication request fails, we will send a shard-failed message to the master node to fail that replica. However, if there are many ongoing write replication requests and the master node is busy, we might overwhelm the cluster and the master node with many shard-failed requests. This commit tries to minimize the shard-failed requests in the above scenario by caching the ongoing shard-failed requests. This issue was discussed at https://discuss.elastic.co/t/half-dead-node-lead-to-cluster-hang/113658/25.
1 parent 1a3eac0 commit b66ce89
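
The fix boils down to a "send once, notify everyone" cache keyed by the logical failed-shard request: the first caller for a given key actually sends the shard-failed message, later callers for an equal key attach their listeners to the in-flight entry, and the entry is removed once the master responds so the same key can be sent again later. Below is a minimal standalone sketch of that pattern for readers who want the idea without the Elasticsearch plumbing; the names here (RequestDeduplicator, Listener, execute, send) are illustrative and not part of this commit, which implements the same idea with remoteFailedShardsCache and ShardStateAction.CompositeListener in the diff that follows.

import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentMap;
import java.util.function.BiConsumer;

// Illustrative sketch only: deduplicates concurrent, equal requests so the
// downstream action is sent at most once while a request is in flight.
final class RequestDeduplicator<Req> {

    interface Listener {
        void onSuccess();
        void onFailure(Exception e);
    }

    // one entry per logical request that is currently in flight
    private final ConcurrentMap<Req, CompositeListener> inFlight = new ConcurrentHashMap<>();

    void execute(Req request, Listener listener, BiConsumer<Req, Listener> send) {
        final CompositeListener fresh = new CompositeListener(listener);
        final CompositeListener existing = inFlight.putIfAbsent(request, fresh);
        if (existing != null) {
            existing.add(listener);               // piggyback on the request already in flight
            return;
        }
        send.accept(request, new Listener() {     // first caller actually sends
            @Override public void onSuccess() {
                try {
                    fresh.onSuccess();
                } finally {
                    inFlight.remove(request);     // allow an equal request to be sent again later
                }
            }

            @Override public void onFailure(Exception e) {
                try {
                    fresh.onFailure(e);
                } finally {
                    inFlight.remove(request);
                }
            }
        });
    }

    // Fans a single outcome out to every registered listener; a listener added
    // after completion is notified immediately with the recorded outcome.
    private static final class CompositeListener implements Listener {
        private final List<Listener> listeners = new ArrayList<>();
        private boolean completed;
        private Exception failure;

        CompositeListener(Listener first) {
            listeners.add(first);
        }

        void add(Listener listener) {
            final boolean alreadyCompleted;
            synchronized (this) {
                alreadyCompleted = completed;
                if (alreadyCompleted == false) {
                    listeners.add(listener);
                }
            }
            if (alreadyCompleted) {
                deliver(listener);
            }
        }

        @Override public void onSuccess() {
            complete(null);
        }

        @Override public void onFailure(Exception e) {
            complete(e);
        }

        private void complete(Exception e) {
            final List<Listener> toNotify;
            synchronized (this) {
                this.failure = e;
                this.completed = true;
                toNotify = new ArrayList<>(listeners);
            }
            toNotify.forEach(this::deliver);
        }

        private void deliver(Listener listener) {
            if (failure != null) {
                listener.onFailure(failure);
            } else {
                listener.onSuccess();
            }
        }
    }
}

With this sketch, ten threads calling execute with equal request objects while the first send is still pending produce one send and ten notifications, which is exactly the pressure-relief effect the commit is after on the master node.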

File tree

3 files changed: +299, -34 lines

server/src/main/java/org/elasticsearch/cluster/action/shard/ShardStateAction.java

Lines changed: 120 additions & 3 deletions
@@ -25,10 +25,10 @@
 import org.elasticsearch.ExceptionsHelper;
 import org.elasticsearch.Version;
 import org.elasticsearch.cluster.ClusterChangedEvent;
-import org.elasticsearch.cluster.ClusterStateTaskExecutor;
 import org.elasticsearch.cluster.ClusterState;
 import org.elasticsearch.cluster.ClusterStateObserver;
 import org.elasticsearch.cluster.ClusterStateTaskConfig;
+import org.elasticsearch.cluster.ClusterStateTaskExecutor;
 import org.elasticsearch.cluster.ClusterStateTaskListener;
 import org.elasticsearch.cluster.MasterNodeChangePredicate;
 import org.elasticsearch.cluster.NotMasterException;
@@ -48,6 +48,7 @@
 import org.elasticsearch.common.io.stream.StreamOutput;
 import org.elasticsearch.common.settings.Settings;
 import org.elasticsearch.common.unit.TimeValue;
+import org.elasticsearch.common.util.concurrent.ConcurrentCollections;
 import org.elasticsearch.discovery.Discovery;
 import org.elasticsearch.index.shard.ShardId;
 import org.elasticsearch.node.NodeClosedException;
@@ -68,7 +69,9 @@
 import java.util.HashSet;
 import java.util.List;
 import java.util.Locale;
+import java.util.Objects;
 import java.util.Set;
+import java.util.concurrent.ConcurrentMap;
 import java.util.function.Predicate;
 
 public class ShardStateAction extends AbstractComponent {
@@ -80,6 +83,10 @@ public class ShardStateAction extends AbstractComponent {
     private final ClusterService clusterService;
     private final ThreadPool threadPool;
 
+    // a list of shards that failed during replication
+    // we keep track of these shards in order to avoid sending duplicate failed shard requests for a single failing shard.
+    private final ConcurrentMap<FailedShardEntry, CompositeListener> remoteFailedShardsCache = ConcurrentCollections.newConcurrentMap();
+
     @Inject
     public ShardStateAction(Settings settings, ClusterService clusterService, TransportService transportService,
                             AllocationService allocationService, RoutingService routingService, ThreadPool threadPool) {
@@ -146,8 +153,35 @@ private static boolean isMasterChannelException(TransportException exp) {
      */
     public void remoteShardFailed(final ShardId shardId, String allocationId, long primaryTerm, boolean markAsStale, final String message, @Nullable final Exception failure, Listener listener) {
         assert primaryTerm > 0L : "primary term should be strictly positive";
-        FailedShardEntry shardEntry = new FailedShardEntry(shardId, allocationId, primaryTerm, message, failure, markAsStale);
-        sendShardAction(SHARD_FAILED_ACTION_NAME, clusterService.state(), shardEntry, listener);
+        final FailedShardEntry shardEntry = new FailedShardEntry(shardId, allocationId, primaryTerm, message, failure, markAsStale);
+        final CompositeListener compositeListener = new CompositeListener(listener);
+        final CompositeListener existingListener = remoteFailedShardsCache.putIfAbsent(shardEntry, compositeListener);
+        if (existingListener == null) {
+            sendShardAction(SHARD_FAILED_ACTION_NAME, clusterService.state(), shardEntry, new Listener() {
+                @Override
+                public void onSuccess() {
+                    try {
+                        compositeListener.onSuccess();
+                    } finally {
+                        remoteFailedShardsCache.remove(shardEntry);
+                    }
+                }
+                @Override
+                public void onFailure(Exception e) {
+                    try {
+                        compositeListener.onFailure(e);
+                    } finally {
+                        remoteFailedShardsCache.remove(shardEntry);
+                    }
+                }
+            });
+        } else {
+            existingListener.addListener(listener);
+        }
+    }
+
+    int remoteShardFailedCacheSize() {
+        return remoteFailedShardsCache.size();
     }
 
     /**
@@ -414,6 +448,23 @@ public String toString() {
             components.add("markAsStale [" + markAsStale + "]");
             return String.join(", ", components);
         }
+
+        @Override
+        public boolean equals(Object o) {
+            if (this == o) return true;
+            if (o == null || getClass() != o.getClass()) return false;
+            FailedShardEntry that = (FailedShardEntry) o;
+            // Exclude message and exception from equals and hashCode
+            return Objects.equals(this.shardId, that.shardId) &&
+                Objects.equals(this.allocationId, that.allocationId) &&
+                primaryTerm == that.primaryTerm &&
+                markAsStale == that.markAsStale;
+        }
+
+        @Override
+        public int hashCode() {
+            return Objects.hash(shardId, allocationId, primaryTerm, markAsStale);
+        }
     }
 
     public void shardStarted(final ShardRouting shardRouting, final String message, Listener listener) {
@@ -585,6 +636,72 @@ default void onFailure(final Exception e) {
 
     }
 
+    /**
+     * A composite listener that allows registering multiple listeners dynamically.
+     */
+    static final class CompositeListener implements Listener {
+        private boolean isNotified = false;
+        private Exception failure = null;
+        private final List<Listener> listeners = new ArrayList<>();
+
+        CompositeListener(Listener listener) {
+            listeners.add(listener);
+        }
+
+        void addListener(Listener listener) {
+            final boolean ready;
+            synchronized (this) {
+                ready = this.isNotified;
+                if (ready == false) {
+                    listeners.add(listener);
+                }
+            }
+            if (ready) {
+                if (failure != null) {
+                    listener.onFailure(failure);
+                } else {
+                    listener.onSuccess();
+                }
+            }
+        }
+
+        private void onCompleted(Exception failure) {
+            synchronized (this) {
+                this.failure = failure;
+                this.isNotified = true;
+            }
+            RuntimeException firstException = null;
+            for (Listener listener : listeners) {
+                try {
+                    if (failure != null) {
+                        listener.onFailure(failure);
+                    } else {
+                        listener.onSuccess();
+                    }
+                } catch (RuntimeException innerEx) {
+                    if (firstException == null) {
+                        firstException = innerEx;
+                    } else {
+                        firstException.addSuppressed(innerEx);
+                    }
+                }
+            }
+            if (firstException != null) {
+                throw firstException;
+            }
+        }
+
+        @Override
+        public void onSuccess() {
+            onCompleted(null);
+        }
+
+        @Override
+        public void onFailure(Exception failure) {
+            onCompleted(failure);
+        }
+    }
+
     public static class NoLongerPrimaryShardException extends ElasticsearchException {
 
         public NoLongerPrimaryShardException(ShardId shardId, String msg) {
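
The observable contract of the CompositeListener added above: every registered listener is notified exactly once with the single outcome, and a listener registered after the outcome has been recorded is notified immediately. A small hypothetical illustration of that behavior, written as if it lived in the org.elasticsearch.cluster.action.shard package so it can see the package-private class (the class and variable names here are not part of the commit):

package org.elasticsearch.cluster.action.shard;

import java.util.concurrent.atomic.AtomicInteger;

// Hypothetical illustration of the CompositeListener behavior introduced by this commit.
public class CompositeListenerSketch {
    public static void main(String[] args) {
        final AtomicInteger notifications = new AtomicInteger();
        ShardStateAction.Listener counting = new ShardStateAction.Listener() {
            @Override
            public void onSuccess() {
                notifications.incrementAndGet();
            }

            @Override
            public void onFailure(Exception e) {
                throw new AssertionError("unexpected failure", e);
            }
        };

        ShardStateAction.CompositeListener composite = new ShardStateAction.CompositeListener(counting);
        composite.addListener(counting); // a second caller piggybacks before completion
        composite.onSuccess();           // the single outcome fans out: notifications == 2
        composite.addListener(counting); // a late registration is notified immediately: notifications == 3
    }
}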

server/src/test/java/org/elasticsearch/cluster/action/shard/ShardFailedClusterStateTaskExecutorTests.java

Lines changed: 34 additions & 30 deletions
@@ -22,11 +22,11 @@
 import com.carrotsearch.hppc.cursors.ObjectCursor;
 import org.apache.lucene.index.CorruptIndexException;
 import org.elasticsearch.Version;
-import org.elasticsearch.cluster.action.shard.ShardStateAction.FailedShardEntry;
 import org.elasticsearch.cluster.ClusterName;
 import org.elasticsearch.cluster.ClusterState;
 import org.elasticsearch.cluster.ClusterStateTaskExecutor;
 import org.elasticsearch.cluster.ESAllocationTestCase;
+import org.elasticsearch.cluster.action.shard.ShardStateAction.FailedShardEntry;
 import org.elasticsearch.cluster.metadata.IndexMetaData;
 import org.elasticsearch.cluster.metadata.MetaData;
 import org.elasticsearch.cluster.node.DiscoveryNodes;
@@ -43,6 +43,7 @@
 import org.elasticsearch.cluster.routing.allocation.StaleShard;
 import org.elasticsearch.cluster.routing.allocation.decider.ClusterRebalanceAllocationDecider;
 import org.elasticsearch.common.UUIDs;
+import org.elasticsearch.common.collect.Tuple;
 import org.elasticsearch.common.settings.Settings;
 import org.elasticsearch.common.util.set.Sets;
 import org.elasticsearch.index.Index;
@@ -53,9 +54,7 @@
 import java.util.Arrays;
 import java.util.Collections;
 import java.util.List;
-import java.util.Map;
 import java.util.Set;
-import java.util.function.Function;
 import java.util.stream.Collectors;
 import java.util.stream.IntStream;
 
@@ -131,10 +130,15 @@ ClusterState applyFailedShards(ClusterState currentState, List<FailedShard> fail
         tasks.addAll(failingTasks);
         tasks.addAll(nonExistentTasks);
         ClusterStateTaskExecutor.ClusterTasksResult<FailedShardEntry> result = failingExecutor.execute(currentState, tasks);
-        Map<FailedShardEntry, ClusterStateTaskExecutor.TaskResult> taskResultMap =
-            failingTasks.stream().collect(Collectors.toMap(Function.identity(), task -> ClusterStateTaskExecutor.TaskResult.failure(new RuntimeException("simulated applyFailedShards failure"))));
-        taskResultMap.putAll(nonExistentTasks.stream().collect(Collectors.toMap(Function.identity(), task -> ClusterStateTaskExecutor.TaskResult.success())));
-        assertTaskResults(taskResultMap, result, currentState, false);
+        List<Tuple<FailedShardEntry, ClusterStateTaskExecutor.TaskResult>> taskResultList = new ArrayList<>();
+        for (FailedShardEntry failingTask : failingTasks) {
+            taskResultList.add(Tuple.tuple(failingTask,
+                ClusterStateTaskExecutor.TaskResult.failure(new RuntimeException("simulated applyFailedShards failure"))));
+        }
+        for (FailedShardEntry nonExistentTask : nonExistentTasks) {
+            taskResultList.add(Tuple.tuple(nonExistentTask, ClusterStateTaskExecutor.TaskResult.success()));
+        }
+        assertTaskResults(taskResultList, result, currentState, false);
     }
 
     public void testIllegalShardFailureRequests() throws Exception {
@@ -147,14 +151,14 @@ public void testIllegalShardFailureRequests() throws Exception {
             tasks.add(new FailedShardEntry(failingTask.shardId, failingTask.allocationId,
                 randomIntBetween(1, (int) primaryTerm - 1), failingTask.message, failingTask.failure, randomBoolean()));
         }
-        Map<FailedShardEntry, ClusterStateTaskExecutor.TaskResult> taskResultMap =
-            tasks.stream().collect(Collectors.toMap(
-                Function.identity(),
-                task -> ClusterStateTaskExecutor.TaskResult.failure(new ShardStateAction.NoLongerPrimaryShardException(task.shardId,
-                    "primary term [" + task.primaryTerm + "] did not match current primary term [" +
-                        currentState.metaData().index(task.shardId.getIndex()).primaryTerm(task.shardId.id()) + "]"))));
+        List<Tuple<FailedShardEntry, ClusterStateTaskExecutor.TaskResult>> taskResultList = tasks.stream()
+            .map(task -> Tuple.tuple(task, ClusterStateTaskExecutor.TaskResult.failure(
+                new ShardStateAction.NoLongerPrimaryShardException(task.shardId, "primary term ["
+                    + task.primaryTerm + "] did not match current primary term ["
+                    + currentState.metaData().index(task.shardId.getIndex()).primaryTerm(task.shardId.id()) + "]"))))
+            .collect(Collectors.toList());
         ClusterStateTaskExecutor.ClusterTasksResult<FailedShardEntry> result = executor.execute(currentState, tasks);
-        assertTaskResults(taskResultMap, result, currentState, false);
+        assertTaskResults(taskResultList, result, currentState, false);
     }
 
     public void testMarkAsStaleWhenFailingShard() throws Exception {
@@ -251,44 +255,44 @@ private static void assertTasksSuccessful(
         ClusterState clusterState,
         boolean clusterStateChanged
     ) {
-        Map<ShardStateAction.FailedShardEntry, ClusterStateTaskExecutor.TaskResult> taskResultMap =
-            tasks.stream().collect(Collectors.toMap(Function.identity(), task -> ClusterStateTaskExecutor.TaskResult.success()));
-        assertTaskResults(taskResultMap, result, clusterState, clusterStateChanged);
+        List<Tuple<FailedShardEntry, ClusterStateTaskExecutor.TaskResult>> taskResultList = tasks.stream()
+            .map(t -> Tuple.tuple(t, ClusterStateTaskExecutor.TaskResult.success())).collect(Collectors.toList());
+        assertTaskResults(taskResultList, result, clusterState, clusterStateChanged);
     }
 
     private static void assertTaskResults(
-        Map<ShardStateAction.FailedShardEntry, ClusterStateTaskExecutor.TaskResult> taskResultMap,
+        List<Tuple<ShardStateAction.FailedShardEntry, ClusterStateTaskExecutor.TaskResult>> taskResultList,
         ClusterStateTaskExecutor.ClusterTasksResult<ShardStateAction.FailedShardEntry> result,
         ClusterState clusterState,
        boolean clusterStateChanged
     ) {
         // there should be as many task results as tasks
-        assertEquals(taskResultMap.size(), result.executionResults.size());
+        assertEquals(taskResultList.size(), result.executionResults.size());
 
-        for (Map.Entry<ShardStateAction.FailedShardEntry, ClusterStateTaskExecutor.TaskResult> entry : taskResultMap.entrySet()) {
+        for (Tuple<FailedShardEntry, ClusterStateTaskExecutor.TaskResult> entry : taskResultList) {
             // every task should have a corresponding task result
-            assertTrue(result.executionResults.containsKey(entry.getKey()));
+            assertTrue(result.executionResults.containsKey(entry.v1()));
 
             // the task results are as expected
-            assertEquals(entry.getKey().toString(), entry.getValue().isSuccess(), result.executionResults.get(entry.getKey()).isSuccess());
+            assertEquals(entry.v1().toString(), entry.v2().isSuccess(), result.executionResults.get(entry.v1()).isSuccess());
         }
 
         List<ShardRouting> shards = clusterState.getRoutingTable().allShards();
-        for (Map.Entry<ShardStateAction.FailedShardEntry, ClusterStateTaskExecutor.TaskResult> entry : taskResultMap.entrySet()) {
-            if (entry.getValue().isSuccess()) {
+        for (Tuple<FailedShardEntry, ClusterStateTaskExecutor.TaskResult> entry : taskResultList) {
+            if (entry.v2().isSuccess()) {
                 // the shard was successfully failed and so should not be in the routing table
                 for (ShardRouting shard : shards) {
                     if (shard.assignedToNode()) {
-                        assertFalse("entry key " + entry.getKey() + ", shard routing " + shard,
-                            entry.getKey().getShardId().equals(shard.shardId()) &&
-                                entry.getKey().getAllocationId().equals(shard.allocationId().getId()));
+                        assertFalse("entry key " + entry.v1() + ", shard routing " + shard,
+                            entry.v1().getShardId().equals(shard.shardId()) &&
+                                entry.v1().getAllocationId().equals(shard.allocationId().getId()));
                     }
                 }
             } else {
                 // check we saw the expected failure
-                ClusterStateTaskExecutor.TaskResult actualResult = result.executionResults.get(entry.getKey());
-                assertThat(actualResult.getFailure(), instanceOf(entry.getValue().getFailure().getClass()));
-                assertThat(actualResult.getFailure().getMessage(), equalTo(entry.getValue().getFailure().getMessage()));
+                ClusterStateTaskExecutor.TaskResult actualResult = result.executionResults.get(entry.v1());
+                assertThat(actualResult.getFailure(), instanceOf(entry.v2().getFailure().getClass()));
+                assertThat(actualResult.getFailure().getMessage(), equalTo(entry.v2().getFailure().getMessage()));
             }
         }
