8
8
9
9
package org .elasticsearch .action .support .replication ;
10
10
11
- import org .elasticsearch .ExceptionsHelper ;
12
11
import org .elasticsearch .action .ActionListener ;
12
+ import org .elasticsearch .action .ActionRunnable ;
13
13
import org .elasticsearch .action .ActionType ;
14
14
import org .elasticsearch .action .support .ActionFilters ;
15
15
import org .elasticsearch .action .support .DefaultShardOperationFailedException ;
26
26
import org .elasticsearch .cluster .routing .IndexRoutingTable ;
27
27
import org .elasticsearch .cluster .service .ClusterService ;
28
28
import org .elasticsearch .common .io .stream .Writeable ;
29
+ import org .elasticsearch .core .CheckedConsumer ;
29
30
import org .elasticsearch .index .shard .ShardId ;
30
31
import org .elasticsearch .tasks .Task ;
31
32
import org .elasticsearch .transport .TransportService ;
33
+ import org .elasticsearch .transport .Transports ;
32
34
33
35
import java .util .ArrayList ;
34
36
import java .util .Arrays ;
35
37
import java .util .List ;
36
- import java .util .concurrent . CopyOnWriteArrayList ;
38
+ import java .util .Map ;
37
39
38
40
/**
39
41
* Base class for requests that should be executed on all shards of an index or several indices.
@@ -49,6 +51,7 @@ public abstract class TransportBroadcastReplicationAction<
49
51
private final ClusterService clusterService ;
50
52
private final IndexNameExpressionResolver indexNameExpressionResolver ;
51
53
private final NodeClient client ;
54
+ private final String executor ;
52
55
53
56
public TransportBroadcastReplicationAction (
54
57
String name ,
@@ -58,58 +61,112 @@ public TransportBroadcastReplicationAction(
58
61
NodeClient client ,
59
62
ActionFilters actionFilters ,
60
63
IndexNameExpressionResolver indexNameExpressionResolver ,
61
- ActionType <ShardResponse > replicatedBroadcastShardAction
64
+ ActionType <ShardResponse > replicatedBroadcastShardAction ,
65
+ String executor
62
66
) {
63
67
super (name , transportService , actionFilters , requestReader );
64
68
this .client = client ;
65
69
this .replicatedBroadcastShardAction = replicatedBroadcastShardAction ;
66
70
this .clusterService = clusterService ;
67
71
this .indexNameExpressionResolver = indexNameExpressionResolver ;
72
+ this .executor = executor ;
68
73
}
69
74
70
75
@ Override
71
76
protected void doExecute (Task task , Request request , ActionListener <Response > listener ) {
72
- final ClusterState clusterState = clusterService .state ();
73
- List <ShardId > shards = shards (request , clusterState );
74
- final CopyOnWriteArrayList <ShardResponse > shardsResponses = new CopyOnWriteArrayList <>();
75
- try (var refs = new RefCountingRunnable (() -> finishAndNotifyListener (listener , shardsResponses ))) {
76
- for (final ShardId shardId : shards ) {
77
- ActionListener <ShardResponse > shardActionListener = new ActionListener <ShardResponse >() {
78
- @ Override
79
- public void onResponse (ShardResponse shardResponse ) {
80
- shardsResponses .add (shardResponse );
81
- logger .trace ("{}: got response from {}" , actionName , shardId );
77
+ clusterService .threadPool ().executor (executor ).execute (ActionRunnable .wrap (listener , createAsyncAction (task , request )));
78
+ }
79
+
80
+ private CheckedConsumer <ActionListener <Response >, Exception > createAsyncAction (Task task , Request request ) {
81
+ return new CheckedConsumer <ActionListener <Response >, Exception >() {
82
+
83
+ private int totalShardCopyCount ;
84
+ private int successShardCopyCount ;
85
+ private final List <DefaultShardOperationFailedException > allFailures = new ArrayList <>();
86
+
87
+ @ Override
88
+ public void accept (ActionListener <Response > listener ) {
89
+ assert totalShardCopyCount == 0 && successShardCopyCount == 0 && allFailures .isEmpty () : "shouldn't call this twice" ;
90
+
91
+ final ClusterState clusterState = clusterService .state ();
92
+ final List <ShardId > shards = shards (request , clusterState );
93
+ final Map <String , IndexMetadata > indexMetadataByName = clusterState .getMetadata ().indices ();
94
+
95
+ try (var refs = new RefCountingRunnable (() -> finish (listener ))) {
96
+ for (final ShardId shardId : shards ) {
97
+ // NB This sends O(#shards) requests in a tight loop; TODO add some throttling here?
98
+ shardExecute (
99
+ task ,
100
+ request ,
101
+ shardId ,
102
+ ActionListener .releaseAfter (new ReplicationResponseActionListener (shardId , indexMetadataByName ), refs .acquire ())
103
+ );
82
104
}
105
+ }
106
+ }
107
+
108
+ private synchronized void addShardResponse (int numCopies , int successful , List <DefaultShardOperationFailedException > failures ) {
109
+ totalShardCopyCount += numCopies ;
110
+ successShardCopyCount += successful ;
111
+ allFailures .addAll (failures );
112
+ }
113
+
114
+ void finish (ActionListener <Response > listener ) {
115
+ // no need for synchronized here, the RefCountingRunnable guarantees that all the addShardResponse calls happen-before here
116
+ logger .trace ("{}: got all shard responses" , actionName );
117
+ listener .onResponse (newResponse (successShardCopyCount , allFailures .size (), totalShardCopyCount , allFailures ));
118
+ }
119
+
120
+ class ReplicationResponseActionListener implements ActionListener <ShardResponse > {
121
+ private final ShardId shardId ;
122
+ private final Map <String , IndexMetadata > indexMetadataByName ;
123
+
124
+ ReplicationResponseActionListener (ShardId shardId , Map <String , IndexMetadata > indexMetadataByName ) {
125
+ this .shardId = shardId ;
126
+ this .indexMetadataByName = indexMetadataByName ;
127
+ }
83
128
84
- @ Override
85
- public void onFailure (Exception e ) {
86
- logger .trace ("{}: got failure from {}" , actionName , shardId );
87
- int totalNumCopies = clusterState .getMetadata ().getIndexSafe (shardId .getIndex ()).getNumberOfReplicas () + 1 ;
88
- ShardResponse shardResponse = newShardResponse ();
89
- ReplicationResponse .ShardInfo .Failure [] failures ;
90
- if (TransportActions .isShardNotAvailableException (e )) {
91
- failures = new ReplicationResponse .ShardInfo .Failure [0 ];
92
- } else {
93
- ReplicationResponse .ShardInfo .Failure failure = new ReplicationResponse .ShardInfo .Failure (
94
- shardId ,
95
- null ,
96
- e ,
97
- ExceptionsHelper .status (e ),
98
- true
99
- );
100
- failures = new ReplicationResponse .ShardInfo .Failure [totalNumCopies ];
101
- Arrays .fill (failures , failure );
102
- }
103
- shardResponse .setShardInfo (new ReplicationResponse .ShardInfo (totalNumCopies , 0 , failures ));
104
- shardsResponses .add (shardResponse );
129
+ @ Override
130
+ public void onResponse (ShardResponse shardResponse ) {
131
+ assert shardResponse != null ;
132
+ logger .trace ("{}: got response from {}" , actionName , shardId );
133
+ addShardResponse (
134
+ shardResponse .getShardInfo ().getTotal (),
135
+ shardResponse .getShardInfo ().getSuccessful (),
136
+ Arrays .stream (shardResponse .getShardInfo ().getFailures ())
137
+ .map (
138
+ f -> new DefaultShardOperationFailedException (
139
+ new BroadcastShardOperationFailedException (shardId , f .getCause ())
140
+ )
141
+ )
142
+ .toList ()
143
+ );
144
+ }
145
+
146
+ @ Override
147
+ public void onFailure (Exception e ) {
148
+ logger .trace ("{}: got failure from {}" , actionName , shardId );
149
+ final int numCopies = indexMetadataByName .get (shardId .getIndexName ()).getNumberOfReplicas () + 1 ;
150
+ final List <DefaultShardOperationFailedException > result ;
151
+ if (TransportActions .isShardNotAvailableException (e )) {
152
+ result = List .of ();
153
+ } else {
154
+ final var failures = new DefaultShardOperationFailedException [numCopies ];
155
+ Arrays .fill (
156
+ failures ,
157
+ new DefaultShardOperationFailedException (new BroadcastShardOperationFailedException (shardId , e ))
158
+ );
159
+ result = Arrays .asList (failures );
105
160
}
106
- } ;
107
- shardExecute ( task , request , shardId , ActionListener . releaseAfter ( shardActionListener , refs . acquire ()));
161
+ addShardResponse ( numCopies , 0 , result ) ;
162
+ }
108
163
}
109
- }
164
+
165
+ };
110
166
}
111
167
112
168
protected void shardExecute (Task task , Request request , ShardId shardId , ActionListener <ShardResponse > shardActionListener ) {
169
+ assert Transports .assertNotTransportThread ("may hit all the shards" );
113
170
ShardRequest shardRequest = newShardRequest (request , shardId );
114
171
shardRequest .setParentTask (clusterService .localNode ().getId (), task .getId ());
115
172
client .executeLocally (replicatedBroadcastShardAction , shardRequest , shardActionListener );
@@ -119,6 +176,7 @@ protected void shardExecute(Task task, Request request, ShardId shardId, ActionL
119
176
* @return all shard ids the request should run on
120
177
*/
121
178
protected List <ShardId > shards (Request request , ClusterState clusterState ) {
179
+ assert Transports .assertNotTransportThread ("may hit all the shards" );
122
180
List <ShardId > shardIds = new ArrayList <>();
123
181
String [] concreteIndices = indexNameExpressionResolver .concreteIndexNames (clusterState , request );
124
182
for (String index : concreteIndices ) {
@@ -133,43 +191,13 @@ protected List<ShardId> shards(Request request, ClusterState clusterState) {
133
191
return shardIds ;
134
192
}
135
193
136
- protected abstract ShardResponse newShardResponse ();
137
-
138
194
protected abstract ShardRequest newShardRequest (Request request , ShardId shardId );
139
195
140
- private void finishAndNotifyListener (ActionListener <Response > listener , CopyOnWriteArrayList <ShardResponse > shardsResponses ) {
141
- logger .trace ("{}: got all shard responses" , actionName );
142
- int successfulShards = 0 ;
143
- int failedShards = 0 ;
144
- int totalNumCopies = 0 ;
145
- List <DefaultShardOperationFailedException > shardFailures = null ;
146
- for (int i = 0 ; i < shardsResponses .size (); i ++) {
147
- ReplicationResponse shardResponse = shardsResponses .get (i );
148
- if (shardResponse == null ) {
149
- // non active shard, ignore
150
- } else {
151
- failedShards += shardResponse .getShardInfo ().getFailed ();
152
- successfulShards += shardResponse .getShardInfo ().getSuccessful ();
153
- totalNumCopies += shardResponse .getShardInfo ().getTotal ();
154
- if (shardFailures == null ) {
155
- shardFailures = new ArrayList <>();
156
- }
157
- for (ReplicationResponse .ShardInfo .Failure failure : shardResponse .getShardInfo ().getFailures ()) {
158
- shardFailures .add (
159
- new DefaultShardOperationFailedException (
160
- new BroadcastShardOperationFailedException (failure .fullShardId (), failure .getCause ())
161
- )
162
- );
163
- }
164
- }
165
- }
166
- listener .onResponse (newResponse (successfulShards , failedShards , totalNumCopies , shardFailures ));
167
- }
168
-
169
196
protected abstract Response newResponse (
170
197
int successfulShards ,
171
198
int failedShards ,
172
199
int totalNumCopies ,
173
200
List <DefaultShardOperationFailedException > shardFailures
174
201
);
202
+
175
203
}
0 commit comments