39
39
import org .elasticsearch .search .SearchHit ;
40
40
import org .elasticsearch .search .SearchHits ;
41
41
import org .elasticsearch .search .SearchPhaseResult ;
42
+ import org .elasticsearch .search .SearchService ;
42
43
import org .elasticsearch .search .SearchShardTarget ;
43
44
import org .elasticsearch .search .aggregations .InternalAggregation ;
44
45
import org .elasticsearch .search .aggregations .InternalAggregation .ReduceContext ;
65
66
import java .util .Map ;
66
67
import java .util .function .Function ;
67
68
import java .util .function .IntFunction ;
69
+ import java .util .stream .Collectors ;
68
70
69
71
public final class SearchPhaseController {
70
72
@@ -427,6 +429,15 @@ private ReducedQueryPhase reducedQueryPhase(Collection<? extends SearchPhaseResu
427
429
return new ReducedQueryPhase (totalHits , topDocsStats .fetchHits , topDocsStats .getMaxScore (),
428
430
false , null , null , null , null , SortedTopDocs .EMPTY , null , numReducePhases , 0 , 0 , true );
429
431
}
432
+ int total = queryResults .size ();
433
+ queryResults = queryResults .stream ()
434
+ .filter (res -> res .queryResult ().isNull () == false )
435
+ .collect (Collectors .toList ());
436
+ String errorMsg = "must have at least one non-empty search result, got 0 out of " + total ;
437
+ assert queryResults .isEmpty () == false : errorMsg ;
438
+ if (queryResults .isEmpty ()) {
439
+ throw new IllegalStateException (errorMsg );
440
+ }
430
441
final QuerySearchResult firstResult = queryResults .stream ().findFirst ().get ().queryResult ();
431
442
final boolean hasSuggest = firstResult .suggest () != null ;
432
443
final boolean hasProfileResults = firstResult .hasProfileResults ();
@@ -497,6 +508,18 @@ private ReducedQueryPhase reducedQueryPhase(Collection<? extends SearchPhaseResu
497
508
firstResult .sortValueFormats (), numReducePhases , size , from , false );
498
509
}
499
510
511
+ /*
512
+ * Returns the size of the requested top documents (from + size)
513
+ */
514
+ static int getTopDocsSize (SearchRequest request ) {
515
+ if (request .source () == null ) {
516
+ return SearchService .DEFAULT_SIZE ;
517
+ }
518
+ SearchSourceBuilder source = request .source ();
519
+ return (source .size () == -1 ? SearchService .DEFAULT_SIZE : source .size ()) +
520
+ (source .from () == -1 ? SearchService .DEFAULT_FROM : source .from ());
521
+ }
522
+
500
523
public static final class ReducedQueryPhase {
501
524
// the sum of all hits across all reduces shards
502
525
final TotalHits totalHits ;
@@ -576,6 +599,7 @@ static final class QueryPhaseResultConsumer extends ArraySearchPhaseResults<Sear
576
599
private final SearchProgressListener progressListener ;
577
600
private int numReducePhases = 0 ;
578
601
private final TopDocsStats topDocsStats ;
602
+ private final int topNSize ;
579
603
private final boolean performFinalReduce ;
580
604
581
605
/**
@@ -589,7 +613,7 @@ static final class QueryPhaseResultConsumer extends ArraySearchPhaseResults<Sear
589
613
*/
590
614
private QueryPhaseResultConsumer (SearchProgressListener progressListener , SearchPhaseController controller ,
591
615
int expectedResultSize , int bufferSize , boolean hasTopDocs , boolean hasAggs ,
592
- int trackTotalHitsUpTo , boolean performFinalReduce ) {
616
+ int trackTotalHitsUpTo , int topNSize , boolean performFinalReduce ) {
593
617
super (expectedResultSize );
594
618
if (expectedResultSize != 1 && bufferSize < 2 ) {
595
619
throw new IllegalArgumentException ("buffer size must be >= 2 if there is more than one expected result" );
@@ -610,6 +634,7 @@ private QueryPhaseResultConsumer(SearchProgressListener progressListener, Search
610
634
this .hasAggs = hasAggs ;
611
635
this .bufferSize = bufferSize ;
612
636
this .topDocsStats = new TopDocsStats (trackTotalHitsUpTo );
637
+ this .topNSize = topNSize ;
613
638
this .performFinalReduce = performFinalReduce ;
614
639
}
615
640
@@ -622,36 +647,38 @@ public void consumeResult(SearchPhaseResult result) {
622
647
}
623
648
624
649
private synchronized void consumeInternal (QuerySearchResult querySearchResult ) {
625
- if (index == bufferSize ) {
650
+ if (querySearchResult .isNull () == false ) {
651
+ if (index == bufferSize ) {
652
+ if (hasAggs ) {
653
+ ReduceContext reduceContext = controller .reduceContextFunction .apply (false );
654
+ InternalAggregations reducedAggs = InternalAggregations .topLevelReduce (Arrays .asList (aggsBuffer ), reduceContext );
655
+ Arrays .fill (aggsBuffer , null );
656
+ aggsBuffer [0 ] = reducedAggs ;
657
+ }
658
+ if (hasTopDocs ) {
659
+ TopDocs reducedTopDocs = mergeTopDocs (Arrays .asList (topDocsBuffer ),
660
+ // we have to merge here in the same way we collect on a shard
661
+ topNSize , 0 );
662
+ Arrays .fill (topDocsBuffer , null );
663
+ topDocsBuffer [0 ] = reducedTopDocs ;
664
+ }
665
+ numReducePhases ++;
666
+ index = 1 ;
667
+ if (hasAggs ) {
668
+ progressListener .notifyPartialReduce (progressListener .searchShards (processedShards ),
669
+ topDocsStats .getTotalHits (), aggsBuffer [0 ], numReducePhases );
670
+ }
671
+ }
672
+ final int i = index ++;
626
673
if (hasAggs ) {
627
- ReduceContext reduceContext = controller .reduceContextFunction .apply (false );
628
- InternalAggregations reducedAggs = InternalAggregations .topLevelReduce (Arrays .asList (aggsBuffer ), reduceContext );
629
- Arrays .fill (aggsBuffer , null );
630
- aggsBuffer [0 ] = reducedAggs ;
674
+ aggsBuffer [i ] = (InternalAggregations ) querySearchResult .consumeAggs ();
631
675
}
632
676
if (hasTopDocs ) {
633
- TopDocs reducedTopDocs = mergeTopDocs (Arrays .asList (topDocsBuffer ),
634
- // we have to merge here in the same way we collect on a shard
635
- querySearchResult .from () + querySearchResult .size (), 0 );
636
- Arrays .fill (topDocsBuffer , null );
637
- topDocsBuffer [0 ] = reducedTopDocs ;
677
+ final TopDocsAndMaxScore topDocs = querySearchResult .consumeTopDocs (); // can't be null
678
+ topDocsStats .add (topDocs , querySearchResult .searchTimedOut (), querySearchResult .terminatedEarly ());
679
+ setShardIndex (topDocs .topDocs , querySearchResult .getShardIndex ());
680
+ topDocsBuffer [i ] = topDocs .topDocs ;
638
681
}
639
- numReducePhases ++;
640
- index = 1 ;
641
- if (hasAggs ) {
642
- progressListener .notifyPartialReduce (progressListener .searchShards (processedShards ),
643
- topDocsStats .getTotalHits (), aggsBuffer [0 ], numReducePhases );
644
- }
645
- }
646
- final int i = index ++;
647
- if (hasAggs ) {
648
- aggsBuffer [i ] = (InternalAggregations ) querySearchResult .consumeAggs ();
649
- }
650
- if (hasTopDocs ) {
651
- final TopDocsAndMaxScore topDocs = querySearchResult .consumeTopDocs (); // can't be null
652
- topDocsStats .add (topDocs , querySearchResult .searchTimedOut (), querySearchResult .terminatedEarly ());
653
- setShardIndex (topDocs .topDocs , querySearchResult .getShardIndex ());
654
- topDocsBuffer [i ] = topDocs .topDocs ;
655
682
}
656
683
processedShards [querySearchResult .getShardIndex ()] = querySearchResult .getSearchShardTarget ();
657
684
}
@@ -706,9 +733,10 @@ ArraySearchPhaseResults<SearchPhaseResult> newSearchPhaseResults(SearchProgressL
706
733
if (isScrollRequest == false && (hasAggs || hasTopDocs )) {
707
734
// no incremental reduce if scroll is used - we only hit a single shard or sometimes more...
708
735
if (request .getBatchedReduceSize () < numShards ) {
736
+ int topNSize = getTopDocsSize (request );
709
737
// only use this if there are aggs and if there are more shards than we should reduce at once
710
738
return new QueryPhaseResultConsumer (listener , this , numShards , request .getBatchedReduceSize (), hasTopDocs , hasAggs ,
711
- trackTotalHitsUpTo , request .isFinalReduce ());
739
+ trackTotalHitsUpTo , topNSize , request .isFinalReduce ());
712
740
}
713
741
}
714
742
return new ArraySearchPhaseResults <SearchPhaseResult >(numShards ) {
@@ -731,7 +759,7 @@ ReducedQueryPhase reduce() {
731
759
732
760
static final class TopDocsStats {
733
761
final int trackTotalHitsUpTo ;
734
- private long totalHits ;
762
+ long totalHits ;
735
763
private TotalHits .Relation totalHitsRelation ;
736
764
long fetchHits ;
737
765
private float maxScore = Float .NEGATIVE_INFINITY ;
0 commit comments