Skip to content

Commit 38d2935

Browse files
authored
Avoid double term construction in DfsPhase (elastic#38716)
DfsPhase captures terms used for scoring a query in order to build global term statistics across multiple shards for more accurate scoring. It currently does this by building the query's `Weight` and calling `extractTerms` on it to collect terms, and then calling `IndexSearcher.termStatistics()` for each collected term. This duplicates work, however, as the various `Weight` implementations will already have collected these statistics at construction time. This commit replaces this roundabout way of collecting stats, instead using a delegating IndexSearcher that collects the term contexts and statistics when `IndexSearcher.termStatistics()` is called from the Weight. It also fixes a bug when using rescorers, where a `QueryRescorer` would calculate distributed term statistics but ignore field statistics. `Rescorer.extractTerms` has been removed and replaced with a new method on `RescoreContext` that returns any queries used by the rescore implementation. The delegating IndexSearcher then collects term contexts and statistics in the same way described above for each Query.
1 parent b76a380 commit 38d2935

File tree

5 files changed

+59
-115
lines changed

5 files changed

+59
-115
lines changed

plugins/examples/rescore/src/main/java/org/elasticsearch/example/rescore/ExampleRescoreBuilder.java

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,6 @@
2020
package org.elasticsearch.example.rescore;
2121

2222
import org.apache.lucene.index.LeafReaderContext;
23-
import org.apache.lucene.index.Term;
2423
import org.apache.lucene.search.Explanation;
2524
import org.apache.lucene.search.IndexSearcher;
2625
import org.apache.lucene.search.ScoreDoc;
@@ -46,7 +45,6 @@
4645
import java.util.Arrays;
4746
import java.util.Iterator;
4847
import java.util.Objects;
49-
import java.util.Set;
5048

5149
import static java.util.Collections.singletonList;
5250
import static org.elasticsearch.common.xcontent.ConstructingObjectParser.constructorArg;
@@ -224,9 +222,5 @@ public Explanation explain(int topLevelDocId, IndexSearcher searcher, RescoreCon
224222
return Explanation.match(context.factor, "test", singletonList(sourceExplanation));
225223
}
226224

227-
@Override
228-
public void extractTerms(IndexSearcher searcher, RescoreContext rescoreContext, Set<Term> termsSet) {
229-
// Since we don't use queries there are no terms to extract.
230-
}
231225
}
232226
}

server/src/main/java/org/elasticsearch/search/dfs/DfsPhase.java

Lines changed: 37 additions & 88 deletions
Original file line numberDiff line numberDiff line change
@@ -19,14 +19,12 @@
1919

2020
package org.elasticsearch.search.dfs;
2121

22-
import com.carrotsearch.hppc.ObjectHashSet;
2322
import com.carrotsearch.hppc.ObjectObjectHashMap;
24-
import com.carrotsearch.hppc.cursors.ObjectCursor;
25-
26-
import org.apache.lucene.index.IndexReaderContext;
2723
import org.apache.lucene.index.Term;
2824
import org.apache.lucene.index.TermStates;
2925
import org.apache.lucene.search.CollectionStatistics;
26+
import org.apache.lucene.search.IndexSearcher;
27+
import org.apache.lucene.search.Query;
3028
import org.apache.lucene.search.ScoreMode;
3129
import org.apache.lucene.search.TermStatistics;
3230
import org.elasticsearch.common.collect.HppcMaps;
@@ -36,9 +34,8 @@
3634
import org.elasticsearch.tasks.TaskCancelledException;
3735

3836
import java.io.IOException;
39-
import java.util.AbstractSet;
40-
import java.util.Collection;
41-
import java.util.Iterator;
37+
import java.util.HashMap;
38+
import java.util.Map;
4239

4340
/**
4441
* Dfs phase of a search request, used to make scoring 100% accurate by collecting additional info from each shard before the query phase.
@@ -52,101 +49,53 @@ public void preProcess(SearchContext context) {
5249

5350
@Override
5451
public void execute(SearchContext context) {
55-
final ObjectHashSet<Term> termsSet = new ObjectHashSet<>();
5652
try {
57-
context.searcher().createWeight(context.searcher().rewrite(context.query()), ScoreMode.COMPLETE, 1f)
58-
.extractTerms(new DelegateSet(termsSet));
53+
ObjectObjectHashMap<String, CollectionStatistics> fieldStatistics = HppcMaps.newNoNullKeysMap();
54+
Map<Term, TermStatistics> stats = new HashMap<>();
55+
IndexSearcher searcher = new IndexSearcher(context.searcher().getIndexReader()) {
56+
@Override
57+
public TermStatistics termStatistics(Term term, TermStates states) throws IOException {
58+
if (context.isCancelled()) {
59+
throw new TaskCancelledException("cancelled");
60+
}
61+
TermStatistics ts = super.termStatistics(term, states);
62+
if (ts != null) {
63+
stats.put(term, ts);
64+
}
65+
return ts;
66+
}
67+
68+
@Override
69+
public CollectionStatistics collectionStatistics(String field) throws IOException {
70+
if (context.isCancelled()) {
71+
throw new TaskCancelledException("cancelled");
72+
}
73+
CollectionStatistics cs = super.collectionStatistics(field);
74+
if (cs != null) {
75+
fieldStatistics.put(field, cs);
76+
}
77+
return cs;
78+
}
79+
};
80+
81+
searcher.createWeight(context.searcher().rewrite(context.query()), ScoreMode.COMPLETE, 1);
5982
for (RescoreContext rescoreContext : context.rescore()) {
60-
try {
61-
rescoreContext.rescorer().extractTerms(context.searcher(), rescoreContext, new DelegateSet(termsSet));
62-
} catch (IOException e) {
63-
throw new IllegalStateException("Failed to extract terms", e);
83+
for (Query query : rescoreContext.getQueries()) {
84+
searcher.createWeight(context.searcher().rewrite(query), ScoreMode.COMPLETE, 1);
6485
}
6586
}
6687

67-
Term[] terms = termsSet.toArray(Term.class);
88+
Term[] terms = stats.keySet().toArray(new Term[0]);
6889
TermStatistics[] termStatistics = new TermStatistics[terms.length];
69-
IndexReaderContext indexReaderContext = context.searcher().getTopReaderContext();
7090
for (int i = 0; i < terms.length; i++) {
71-
if(context.isCancelled()) {
72-
throw new TaskCancelledException("cancelled");
73-
}
74-
// LUCENE 4 UPGRADE: cache TermStates?
75-
TermStates termContext = TermStates.build(indexReaderContext, terms[i], true);
76-
termStatistics[i] = context.searcher().termStatistics(terms[i], termContext);
77-
}
78-
79-
ObjectObjectHashMap<String, CollectionStatistics> fieldStatistics = HppcMaps.newNoNullKeysMap();
80-
for (Term term : terms) {
81-
assert term.field() != null : "field is null";
82-
if (fieldStatistics.containsKey(term.field()) == false) {
83-
final CollectionStatistics collectionStatistics = context.searcher().collectionStatistics(term.field());
84-
if (collectionStatistics != null) {
85-
fieldStatistics.put(term.field(), collectionStatistics);
86-
}
87-
if(context.isCancelled()) {
88-
throw new TaskCancelledException("cancelled");
89-
}
90-
}
91+
termStatistics[i] = stats.get(terms[i]);
9192
}
9293

9394
context.dfsResult().termsStatistics(terms, termStatistics)
9495
.fieldStatistics(fieldStatistics)
9596
.maxDoc(context.searcher().getIndexReader().maxDoc());
9697
} catch (Exception e) {
9798
throw new DfsPhaseExecutionException(context, "Exception during dfs phase", e);
98-
} finally {
99-
termsSet.clear(); // don't hold on to terms
100-
}
101-
}
102-
103-
// We need to bridge to JCF world, b/c of Query#extractTerms
104-
private static class DelegateSet extends AbstractSet<Term> {
105-
106-
private final ObjectHashSet<Term> delegate;
107-
108-
private DelegateSet(ObjectHashSet<Term> delegate) {
109-
this.delegate = delegate;
110-
}
111-
112-
@Override
113-
public boolean add(Term term) {
114-
return delegate.add(term);
115-
}
116-
117-
@Override
118-
public boolean addAll(Collection<? extends Term> terms) {
119-
boolean result = false;
120-
for (Term term : terms) {
121-
result = delegate.add(term);
122-
}
123-
return result;
124-
}
125-
126-
@Override
127-
public Iterator<Term> iterator() {
128-
final Iterator<ObjectCursor<Term>> iterator = delegate.iterator();
129-
return new Iterator<Term>() {
130-
@Override
131-
public boolean hasNext() {
132-
return iterator.hasNext();
133-
}
134-
135-
@Override
136-
public Term next() {
137-
return iterator.next().value;
138-
}
139-
140-
@Override
141-
public void remove() {
142-
throw new UnsupportedOperationException();
143-
}
144-
};
145-
}
146-
147-
@Override
148-
public int size() {
149-
return delegate.size();
15099
}
151100
}
152101

server/src/main/java/org/elasticsearch/search/rescore/QueryRescorer.java

Lines changed: 8 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -19,19 +19,19 @@
1919

2020
package org.elasticsearch.search.rescore;
2121

22-
import org.apache.lucene.index.Term;
2322
import org.apache.lucene.search.Explanation;
2423
import org.apache.lucene.search.IndexSearcher;
2524
import org.apache.lucene.search.Query;
2625
import org.apache.lucene.search.ScoreDoc;
27-
import org.apache.lucene.search.ScoreMode;
2826
import org.apache.lucene.search.TopDocs;
2927

3028
import java.io.IOException;
3129
import java.util.Arrays;
30+
import java.util.Collections;
3231
import java.util.Comparator;
32+
import java.util.List;
3333
import java.util.Set;
34-
import java.util.Collections;
34+
3535
import static java.util.stream.Collectors.toSet;
3636

3737
public final class QueryRescorer implements Rescorer {
@@ -170,6 +170,11 @@ public void setQuery(Query query) {
170170
this.query = query;
171171
}
172172

173+
@Override
174+
public List<Query> getQueries() {
175+
return Collections.singletonList(query);
176+
}
177+
173178
public Query query() {
174179
return query;
175180
}
@@ -203,10 +208,4 @@ public void setScoreMode(String scoreMode) {
203208
}
204209
}
205210

206-
@Override
207-
public void extractTerms(IndexSearcher searcher, RescoreContext rescoreContext, Set<Term> termsSet) throws IOException {
208-
Query query = ((QueryRescoreContext) rescoreContext).query();
209-
searcher.createWeight(searcher.rewrite(query), ScoreMode.COMPLETE_NO_SCORES, 1f).extractTerms(termsSet);
210-
}
211-
212211
}

server/src/main/java/org/elasticsearch/search/rescore/RescoreContext.java

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,10 @@
1919

2020
package org.elasticsearch.search.rescore;
2121

22+
import org.apache.lucene.search.Query;
23+
24+
import java.util.Collections;
25+
import java.util.List;
2226
import java.util.Set;
2327

2428
/**
@@ -29,7 +33,7 @@
2933
public class RescoreContext {
3034
private final int windowSize;
3135
private final Rescorer rescorer;
32-
private Set<Integer> resroredDocs; //doc Ids for which rescoring was applied
36+
private Set<Integer> rescoredDocs; //doc Ids for which rescoring was applied
3337

3438
/**
3539
* Build the context.
@@ -55,10 +59,17 @@ public int getWindowSize() {
5559
}
5660

5761
public void setRescoredDocs(Set<Integer> docIds) {
58-
resroredDocs = docIds;
62+
rescoredDocs = docIds;
5963
}
6064

6165
public boolean isRescored(int docId) {
62-
return resroredDocs.contains(docId);
66+
return rescoredDocs.contains(docId);
67+
}
68+
69+
/**
70+
* Returns queries associated with the rescorer
71+
*/
72+
public List<Query> getQueries() {
73+
return Collections.emptyList();
6374
}
6475
}

server/src/main/java/org/elasticsearch/search/rescore/Rescorer.java

Lines changed: 0 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -19,14 +19,11 @@
1919

2020
package org.elasticsearch.search.rescore;
2121

22-
import org.apache.lucene.index.Term;
2322
import org.apache.lucene.search.Explanation;
2423
import org.apache.lucene.search.IndexSearcher;
2524
import org.apache.lucene.search.TopDocs;
26-
import org.elasticsearch.action.search.SearchType;
2725

2826
import java.io.IOException;
29-
import java.util.Set;
3027

3128
/**
3229
* A query rescorer interface used to re-rank the Top-K results of a previously
@@ -61,10 +58,4 @@ public interface Rescorer {
6158
Explanation explain(int topLevelDocId, IndexSearcher searcher, RescoreContext rescoreContext,
6259
Explanation sourceExplanation) throws IOException;
6360

64-
/**
65-
* Extracts all terms needed to execute this {@link Rescorer}. This method
66-
* is executed in a distributed frequency collection roundtrip for
67-
* {@link SearchType#DFS_QUERY_THEN_FETCH}
68-
*/
69-
void extractTerms(IndexSearcher searcher, RescoreContext rescoreContext, Set<Term> termsSet) throws IOException;
7061
}

0 commit comments

Comments (0)