36
36
import org .apache .lucene .util .BitSetIterator ;
37
37
import org .apache .lucene .util .Bits ;
38
38
import org .apache .lucene .util .SparseFixedBitSet ;
39
- import org .elasticsearch .common .util .concurrent .ConcurrentCollections ;
40
39
import org .elasticsearch .core .Releasable ;
41
40
import org .elasticsearch .lucene .util .CombinedBitSet ;
42
41
import org .elasticsearch .search .dfs .AggregatedDfs ;
53
52
import java .util .List ;
54
53
import java .util .Objects ;
55
54
import java .util .PriorityQueue ;
56
- import java .util .Set ;
57
55
import java .util .concurrent .Callable ;
58
56
import java .util .concurrent .Executor ;
59
57
import java .util .stream .Collectors ;
@@ -82,7 +80,6 @@ public class ContextIndexSearcher extends IndexSearcher implements Releasable {
82
80
// don't create slices with less than this number of docs
83
81
private final int minimumDocsPerSlice ;
84
82
85
- private final Set <Thread > timeoutOverwrites = ConcurrentCollections .newConcurrentSet ();
86
83
private volatile boolean timeExceeded = false ;
87
84
88
85
/** constructor for non-concurrent search */
@@ -374,6 +371,8 @@ private <C extends Collector, T> T search(Weight weight, CollectorManager<C, T>
374
371
}
375
372
}
376
373
374
+ private static final ThreadLocal <Boolean > timeoutOverwrites = ThreadLocal .withInitial (() -> false );
375
+
377
376
/**
378
377
* Similar to the lucene implementation, with the following changes made:
379
378
* 1) postCollection is performed after each segment is collected. This is needed for aggregations, performed by search threads
@@ -397,12 +396,12 @@ public void search(LeafReaderContextPartition[] leaves, Weight weight, Collector
397
396
try {
398
397
// Search phase has finished, no longer need to check for timeout
399
398
// otherwise the aggregation post-collection phase might get cancelled.
400
- boolean added = timeoutOverwrites .add ( Thread . currentThread ()) ;
401
- assert added ;
399
+ assert timeoutOverwrites .get () == false ;
400
+ timeoutOverwrites . set ( true ) ;
402
401
doAggregationPostCollection (collector );
403
402
} finally {
404
- boolean removed = timeoutOverwrites .remove ( Thread . currentThread () );
405
- assert removed ;
403
+ assert timeoutOverwrites .get ( );
404
+ timeoutOverwrites . set ( false ) ;
406
405
}
407
406
}
408
407
}
@@ -420,7 +419,7 @@ public boolean timeExceeded() {
420
419
}
421
420
422
421
public void throwTimeExceededException () {
423
- if (timeoutOverwrites .contains ( Thread . currentThread () ) == false ) {
422
+ if (timeoutOverwrites .get ( ) == false ) {
424
423
throw new TimeExceededException ();
425
424
}
426
425
}
0 commit comments