 import org.elasticsearch.search.SearchPhaseResult;
 import org.elasticsearch.search.SearchShardTarget;
 
-import java.io.IOException;
+import java.util.ArrayDeque;
 import java.util.ArrayList;
 import java.util.List;
+import java.util.Map;
+import java.util.concurrent.ConcurrentHashMap;
 import java.util.concurrent.Executor;
 import java.util.concurrent.atomic.AtomicInteger;
 import java.util.stream.Stream;
@@ -52,12 +54,13 @@ abstract class InitialSearchPhase<FirstResult extends SearchPhaseResult> extends
     private final Logger logger;
     private final int expectedTotalOps;
     private final AtomicInteger totalOps = new AtomicInteger();
-    private final AtomicInteger shardExecutionIndex = new AtomicInteger(0);
-    private final int maxConcurrentShardRequests;
+    private final int maxConcurrentRequestsPerNode;
     private final Executor executor;
+    private final Map<String, PendingExecutions> pendingExecutionsPerNode = new ConcurrentHashMap<>();
+    private final boolean throttleConcurrentRequests;
 
     InitialSearchPhase(String name, SearchRequest request, GroupShardsIterator<SearchShardIterator> shardsIts, Logger logger,
-                       int maxConcurrentShardRequests, Executor executor) {
+                       int maxConcurrentRequestsPerNode, Executor executor) {
         super(name);
         this.request = request;
         final List<SearchShardIterator> toSkipIterators = new ArrayList<>();
@@ -77,7 +80,9 @@ abstract class InitialSearchPhase<FirstResult extends SearchPhaseResult> extends
         // on a per shards level we use shardIt.remaining() to increment the totalOps pointer but add 1 for the current shard result
         // we process hence we add one for the non active partition here.
         this.expectedTotalOps = shardsIts.totalSizeWith1ForEmpty();
-        this.maxConcurrentShardRequests = Math.min(maxConcurrentShardRequests, shardsIts.size());
+        this.maxConcurrentRequestsPerNode = Math.min(maxConcurrentRequestsPerNode, shardsIts.size());
+        // in the case where we have fewer shards than maxConcurrentRequestsPerNode we don't need to throttle
+        this.throttleConcurrentRequests = maxConcurrentRequestsPerNode < shardsIts.size();
         this.executor = executor;
     }
 
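Worth spelling out: the constructor only enables throttling when the request spans more shards than the per-node limit covers in one round; otherwise every shard request can go out immediately. A minimal standalone sketch of that decision, in which the variable names mirror the patch but the concrete numbers are invented for illustration:

public class ThrottleDecisionSketch {
    public static void main(String[] args) {
        int maxConcurrentRequestsPerNode = 10; // hypothetical per-node limit
        int shardCount = 5;                    // stands in for shardsIts.size()

        // Clamp the limit: no point allowing more in-flight requests than shards.
        int effectiveLimit = Math.min(maxConcurrentRequestsPerNode, shardCount);

        // Throttle only if shards outnumber the per-node limit; here they do not,
        // so all five shard requests can be dispatched at once.
        boolean throttleConcurrentRequests = maxConcurrentRequestsPerNode < shardCount;

        System.out.println(effectiveLimit + ", " + throttleConcurrentRequests); // 5, false
    }
}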
@@ -109,7 +114,6 @@ private void onShardFailure(final int shardIndex, @Nullable ShardRouting shard,
         if (!lastShard) {
             performPhaseOnShard(shardIndex, shardIt, nextShard);
         } else {
-            maybeExecuteNext(); // move to the next execution if needed
             // no more shards active, add a failure
             if (logger.isDebugEnabled() && !logger.isTraceEnabled()) { // do not double log this exception
                 if (e != null && !TransportActions.isShardNotAvailableException(e)) {
@@ -123,15 +127,12 @@ private void onShardFailure(final int shardIndex, @Nullable ShardRouting shard,
     }
 
     @Override
-    public final void run() throws IOException {
+    public final void run() {
         for (final SearchShardIterator iterator : toSkipShardsIts) {
             assert iterator.skip();
             skipShard(iterator);
         }
         if (shardsIts.size() > 0) {
-            int maxConcurrentShardRequests = Math.min(this.maxConcurrentShardRequests, shardsIts.size());
-            final boolean success = shardExecutionIndex.compareAndSet(0, maxConcurrentShardRequests);
-            assert success;
             assert request.allowPartialSearchResults() != null : "SearchRequest missing setting for allowPartialSearchResults";
             if (request.allowPartialSearchResults() == false) {
                 final StringBuilder missingShards = new StringBuilder();
@@ -152,22 +153,14 @@ public final void run() throws IOException {
                     throw new SearchPhaseExecutionException(getName(), msg, null, ShardSearchFailure.EMPTY_ARRAY);
                 }
             }
-            for (int index = 0; index < maxConcurrentShardRequests; index++) {
+            for (int index = 0; index < shardsIts.size(); index++) {
                 final SearchShardIterator shardRoutings = shardsIts.get(index);
                 assert shardRoutings.skip() == false;
                 performPhaseOnShard(index, shardRoutings, shardRoutings.nextOrNull());
             }
         }
     }
 
-    private void maybeExecuteNext() {
-        final int index = shardExecutionIndex.getAndIncrement();
-        if (index < shardsIts.size()) {
-            final SearchShardIterator shardRoutings = shardsIts.get(index);
-            performPhaseOnShard(index, shardRoutings, shardRoutings.nextOrNull());
-        }
-    }
-
 
     private void maybeFork(final Thread thread, final Runnable runnable) {
         if (thread == Thread.currentThread()) {
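For context on why the code keeps checking the current thread: if the transport layer completes a listener on the very thread that dispatched the request, continuing inline would recurse, and a long chain of shards could overflow the stack. A hedged sketch of the fork-if-same-thread trampoline that maybeFork implements (simplified: the real code submits an AbstractRunnable to the search executor, and the single-thread executor here is just a stand-in):

import java.util.concurrent.Executor;
import java.util.concurrent.Executors;

public class MaybeForkSketch {
    private final Executor executor = Executors.newSingleThreadExecutor();

    // Fork when we are still on the thread that started the phase, so the stack
    // unwinds; run inline when the callback already arrived on another thread.
    void maybeFork(final Thread thread, final Runnable runnable) {
        if (thread == Thread.currentThread()) {
            executor.execute(runnable); // same thread: fork to avoid deep recursion
        } else {
            runnable.run();             // different thread: safe to continue inline
        }
    }
}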
@@ -197,6 +190,59 @@ public boolean isForceExecution() {
         });
     }
 
+    private static final class PendingExecutions {
+        private final int permits;
+        private int permitsTaken = 0;
+        private ArrayDeque<Runnable> queue = new ArrayDeque<>();
+
+        PendingExecutions(int permits) {
+            assert permits > 0 : "not enough permits: " + permits;
+            this.permits = permits;
+        }
+
+        void finishAndRunNext() {
+            synchronized (this) {
+                permitsTaken--;
+                assert permitsTaken >= 0 : "illegal taken permits: " + permitsTaken;
+            }
+            tryRun(null);
+        }
+
+        void tryRun(Runnable runnable) {
+            Runnable r = tryQueue(runnable);
+            if (r != null) {
+                r.run();
+            }
+        }
+
+        private synchronized Runnable tryQueue(Runnable runnable) {
+            Runnable toExecute = null;
+            if (permitsTaken < permits) {
+                permitsTaken++;
+                toExecute = runnable;
+                if (toExecute == null) { // only poll if we don't have anything to execute
+                    toExecute = queue.poll();
+                }
+                if (toExecute == null) {
+                    permitsTaken--;
+                }
+            } else if (runnable != null) {
+                queue.add(runnable);
+            }
+            return toExecute;
+        }
+    }
+
+    private void executeNext(PendingExecutions pendingExecutions, Thread originalThread) {
+        if (pendingExecutions != null) {
+            assert throttleConcurrentRequests;
+            maybeFork(originalThread, pendingExecutions::finishAndRunNext);
+        } else {
+            assert throttleConcurrentRequests == false;
+        }
+    }
+
+
     private void performPhaseOnShard(final int shardIndex, final SearchShardIterator shardIt, final ShardRouting shard) {
         /*
          * We capture the thread that this phase is starting on. When we are called back after executing the phase, we are either on the
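The PendingExecutions class added above is effectively a non-blocking semaphore: a fixed number of permits per node, with overflow runnables parked in a queue rather than blocking a thread, and permits returned via finishAndRunNext. A standalone re-creation for illustration: the class body mirrors the patch, but the demo harness around it is invented.

import java.util.ArrayDeque;

public class PendingExecutionsDemo {
    static final class PendingExecutions {
        private final int permits;
        private int permitsTaken = 0;
        private final ArrayDeque<Runnable> queue = new ArrayDeque<>();

        PendingExecutions(int permits) {
            this.permits = permits;
        }

        // Release a permit, then try to start the next queued runnable, if any.
        void finishAndRunNext() {
            synchronized (this) {
                permitsTaken--;
            }
            tryRun(null);
        }

        // Run the runnable now if a permit is free, otherwise park it in the queue.
        void tryRun(Runnable runnable) {
            Runnable r = tryQueue(runnable);
            if (r != null) {
                r.run(); // executed outside the lock so work never runs while synchronized
            }
        }

        private synchronized Runnable tryQueue(Runnable runnable) {
            Runnable toExecute = null;
            if (permitsTaken < permits) {
                permitsTaken++;
                toExecute = runnable;
                if (toExecute == null) { // called with null from finishAndRunNext: drain the queue
                    toExecute = queue.poll();
                }
                if (toExecute == null) { // nothing to run after all: return the permit
                    permitsTaken--;
                }
            } else if (runnable != null) {
                queue.add(runnable); // all permits taken: park the runnable
            }
            return toExecute;
        }
    }

    public static void main(String[] args) {
        PendingExecutions executions = new PendingExecutions(2);
        for (int i = 0; i < 4; i++) {
            final int task = i;
            executions.tryRun(() -> System.out.println("started task " + task));
        }
        // prints "started task 0" and "started task 1"; tasks 2 and 3 are queued
        executions.finishAndRunNext(); // a permit frees up: prints "started task 2"
    }
}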
@@ -205,29 +251,54 @@ private void performPhaseOnShard(final int shardIndex, final SearchShardIterator
          * could stack overflow. To prevent this, we fork if we are called back on the same thread that execution started on and otherwise
          * we can continue (cf. InitialSearchPhase#maybeFork).
          */
-        final Thread thread = Thread.currentThread();
         if (shard == null) {
             fork(() -> onShardFailure(shardIndex, null, null, shardIt, new NoShardAvailableActionException(shardIt.shardId())));
         } else {
-            try {
-                executePhaseOnShard(shardIt, shard, new SearchActionListener<FirstResult>(new SearchShardTarget(shard.currentNodeId(),
-                    shardIt.shardId(), shardIt.getClusterAlias(), shardIt.getOriginalIndices()), shardIndex) {
-                    @Override
-                    public void innerOnResponse(FirstResult result) {
-                        maybeFork(thread, () -> onShardResult(result, shardIt));
-                    }
+            final PendingExecutions pendingExecutions = throttleConcurrentRequests ?
+                pendingExecutionsPerNode.computeIfAbsent(shard.currentNodeId(), n -> new PendingExecutions(maxConcurrentRequestsPerNode))
+                : null;
+            Runnable r = () -> {
+                final Thread thread = Thread.currentThread();
+                try {
+                    executePhaseOnShard(shardIt, shard, new SearchActionListener<FirstResult>(new SearchShardTarget(shard.currentNodeId(),
+                        shardIt.shardId(), shardIt.getClusterAlias(), shardIt.getOriginalIndices()), shardIndex) {
+                        @Override
+                        public void innerOnResponse(FirstResult result) {
+                            try {
+                                onShardResult(result, shardIt);
+                            } finally {
+                                executeNext(pendingExecutions, thread);
+                            }
+                        }
 
-                    @Override
-                    public void onFailure(Exception t) {
-                        maybeFork(thread, () -> onShardFailure(shardIndex, shard, shard.currentNodeId(), shardIt, t));
+                        @Override
+                        public void onFailure(Exception t) {
+                            try {
+                                onShardFailure(shardIndex, shard, shard.currentNodeId(), shardIt, t);
+                            } finally {
+                                executeNext(pendingExecutions, thread);
+                            }
+                        }
+                    });
+
+
+                } catch (final Exception e) {
+                    try {
+                        /*
+                         * It is possible to run into connection exceptions here because we are getting the connection early and might
+                         * run into nodes that are not connected. In this case, on shard failure will move us to the next shard copy.
+                         */
+                        fork(() -> onShardFailure(shardIndex, shard, shard.currentNodeId(), shardIt, e));
+                    } finally {
+                        executeNext(pendingExecutions, thread);
                     }
-                });
-            } catch (final Exception e) {
-                /*
-                 * It is possible to run into connection exceptions here because we are getting the connection early and might run in to
-                 * nodes that are not connected. In this case, on shard failure will move us to the next shard copy.
-                 */
-                fork(() -> onShardFailure(shardIndex, shard, shard.currentNodeId(), shardIt, e));
+                }
+            };
+            if (pendingExecutions == null) {
+                r.run();
+            } else {
+                assert throttleConcurrentRequests;
+                pendingExecutions.tryRun(r);
             }
         }
     }
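Dispatch now follows one pattern: look up (or lazily create) the node's PendingExecutions, wrap the shard request in a Runnable that returns its permit in a finally block, then either run it inline (unthrottled) or hand it to tryRun. A hedged sketch of that shape, reusing the PendingExecutions re-creation from the earlier sketch (same package assumed; the names and values are illustrative, and the real patch routes the permit release through maybeFork via executeNext, which this sketch omits):

import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;

public class PerNodeDispatchSketch {
    // one throttle per node, created on first use, as in pendingExecutionsPerNode
    private final Map<String, PendingExecutionsDemo.PendingExecutions> perNode = new ConcurrentHashMap<>();
    private final boolean throttleConcurrentRequests = true; // illustrative
    private final int maxConcurrentRequestsPerNode = 5;      // illustrative

    void dispatch(String nodeId, Runnable shardRequest) {
        final PendingExecutionsDemo.PendingExecutions pendingExecutions = throttleConcurrentRequests
            ? perNode.computeIfAbsent(nodeId, n -> new PendingExecutionsDemo.PendingExecutions(maxConcurrentRequestsPerNode))
            : null;
        Runnable r = () -> {
            try {
                shardRequest.run(); // the shard-level request, success or failure callbacks included
            } finally {
                if (pendingExecutions != null) {
                    pendingExecutions.finishAndRunNext(); // free the permit; may start a queued request
                }
            }
        };
        if (pendingExecutions == null) {
            r.run();                     // unthrottled: execute inline
        } else {
            pendingExecutions.tryRun(r); // throttled: run now if a permit is free, else queue
        }
    }
}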
@@ -257,8 +328,6 @@ private void successfulShardExecution(SearchShardIterator shardsIt) {
         } else if (xTotalOps > expectedTotalOps) {
             throw new AssertionError("unexpected higher total ops [" + xTotalOps + "] compared to expected ["
                 + expectedTotalOps + "]");
-        } else if (shardsIt.skip() == false) {
-            maybeExecuteNext();
         }
     }
 
@@ -376,5 +445,4 @@ protected void skipShard(SearchShardIterator iterator) {
         assert iterator.skip();
         successfulShardExecution(iterator);
     }
-
 }