Skip to content

Commit 6f3d152

Browse files
authored
Propagate scoring function through random sampler (#116957) (#117165)
* Propagate scoring function through random sampler. * Update docs/changelog/116957.yaml * Correct score mode in random sampler weight * Fix random sampling with scores and p=1.0 * Unit test with scores * YAML test * Add capability
1 parent 6a8ef8f commit 6f3d152

File tree

8 files changed

+150
-56
lines changed

8 files changed

+150
-56
lines changed

docs/changelog/116957.yaml

+5
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
pr: 116957
2+
summary: Propagate scoring function through random sampler
3+
area: Machine Learning
4+
type: bug
5+
issues: [ 110134 ]

modules/aggregations/build.gradle

+1-1
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ esplugin {
2020

2121
restResources {
2222
restApi {
23-
include '_common', 'indices', 'cluster', 'index', 'search', 'nodes', 'bulk', 'scripts_painless_execute', 'put_script'
23+
include 'capabilities', '_common', 'indices', 'cluster', 'index', 'search', 'nodes', 'bulk', 'scripts_painless_execute', 'put_script'
2424
}
2525
restTests {
2626
// Pulls in all aggregation tests from core AND the forwards v7's core for forwards compatibility

modules/aggregations/src/yamlRestTest/resources/rest-api-spec/test/aggregations/random_sampler.yml

+60
Original file line numberDiff line numberDiff line change
@@ -142,6 +142,66 @@ setup:
142142
}
143143
- match: { aggregations.sampled.mean.value: 1.0 }
144144
---
145+
"Test random_sampler aggregation with scored subagg":
146+
- requires:
147+
capabilities:
148+
- method: POST
149+
path: /_search
150+
capabilities: [ random_sampler_with_scored_subaggs ]
151+
test_runner_features: capabilities
152+
reason: "Support for random sampler with scored subaggs capability required"
153+
- do:
154+
search:
155+
index: data
156+
size: 0
157+
body: >
158+
{
159+
"query": {
160+
"function_score": {
161+
"random_score": {}
162+
}
163+
},
164+
"aggs": {
165+
"sampled": {
166+
"random_sampler": {
167+
"probability": 0.5
168+
},
169+
"aggs": {
170+
"top": {
171+
"top_hits": {}
172+
}
173+
}
174+
}
175+
}
176+
}
177+
- is_true: aggregations.sampled.top.hits
178+
- do:
179+
search:
180+
index: data
181+
size: 0
182+
body: >
183+
{
184+
"query": {
185+
"function_score": {
186+
"random_score": {}
187+
}
188+
},
189+
"aggs": {
190+
"sampled": {
191+
"random_sampler": {
192+
"probability": 1.0
193+
},
194+
"aggs": {
195+
"top": {
196+
"top_hits": {}
197+
}
198+
}
199+
}
200+
}
201+
}
202+
- match: { aggregations.sampled.top.hits.total.value: 6 }
203+
- is_true: aggregations.sampled.top.hits.hits.0._score
204+
---
145205
"Test random_sampler aggregation with poor settings":
146206
- requires:
147207
cluster_features: ["gte_v8.2.0"]

server/src/main/java/org/elasticsearch/rest/action/search/SearchCapabilities.java

+3-1
Original file line numberDiff line numberDiff line change
@@ -23,10 +23,12 @@ private SearchCapabilities() {}
2323
/** Support synthetic source with `bit` type in `dense_vector` field when `index` is set to `false`. */
2424
private static final String BIT_DENSE_VECTOR_SYNTHETIC_SOURCE_CAPABILITY = "bit_dense_vector_synthetic_source";
2525
private static final String NESTED_RETRIEVER_INNER_HITS_SUPPORT = "nested_retriever_inner_hits_support";
26+
private static final String RANDOM_SAMPLER_WITH_SCORED_SUBAGGS = "random_sampler_with_scored_subaggs";
2627

2728
public static final Set<String> CAPABILITIES = Set.of(
2829
RANGE_REGEX_INTERVAL_QUERY_CAPABILITY,
2930
BIT_DENSE_VECTOR_SYNTHETIC_SOURCE_CAPABILITY,
30-
NESTED_RETRIEVER_INNER_HITS_SUPPORT
31+
NESTED_RETRIEVER_INNER_HITS_SUPPORT,
32+
RANDOM_SAMPLER_WITH_SCORED_SUBAGGS
3133
);
3234
}

server/src/main/java/org/elasticsearch/search/aggregations/AggregatorBase.java

+1-1
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ public abstract class AggregatorBase extends Aggregator {
4040

4141
protected final String name;
4242
protected final Aggregator parent;
43-
private final AggregationContext context;
43+
protected final AggregationContext context;
4444
private final Map<String, Object> metadata;
4545

4646
protected final Aggregator[] subAggregators;

server/src/main/java/org/elasticsearch/search/aggregations/bucket/sampler/random/RandomSamplerAggregator.java

+41-12
Original file line numberDiff line numberDiff line change
@@ -9,18 +9,22 @@
99

1010
package org.elasticsearch.search.aggregations.bucket.sampler.random;
1111

12+
import org.apache.lucene.search.BooleanClause;
13+
import org.apache.lucene.search.BooleanQuery;
1214
import org.apache.lucene.search.CollectionTerminatedException;
1315
import org.apache.lucene.search.DocIdSetIterator;
16+
import org.apache.lucene.search.Query;
17+
import org.apache.lucene.search.ScoreMode;
1418
import org.apache.lucene.search.Scorer;
1519
import org.apache.lucene.search.Weight;
1620
import org.apache.lucene.util.Bits;
17-
import org.elasticsearch.common.CheckedSupplier;
1821
import org.elasticsearch.search.aggregations.AggregationExecutionContext;
1922
import org.elasticsearch.search.aggregations.Aggregator;
2023
import org.elasticsearch.search.aggregations.AggregatorFactories;
2124
import org.elasticsearch.search.aggregations.CardinalityUpperBound;
2225
import org.elasticsearch.search.aggregations.InternalAggregation;
2326
import org.elasticsearch.search.aggregations.LeafBucketCollector;
27+
import org.elasticsearch.search.aggregations.LeafBucketCollectorBase;
2428
import org.elasticsearch.search.aggregations.bucket.BucketsAggregator;
2529
import org.elasticsearch.search.aggregations.bucket.SingleBucketAggregator;
2630
import org.elasticsearch.search.aggregations.support.AggregationContext;
@@ -33,14 +37,13 @@ public class RandomSamplerAggregator extends BucketsAggregator implements Single
3337
private final int seed;
3438
private final Integer shardSeed;
3539
private final double probability;
36-
private final CheckedSupplier<Weight, IOException> weightSupplier;
40+
private Weight weight;
3741

3842
RandomSamplerAggregator(
3943
String name,
4044
int seed,
4145
Integer shardSeed,
4246
double probability,
43-
CheckedSupplier<Weight, IOException> weightSupplier,
4447
AggregatorFactories factories,
4548
AggregationContext context,
4649
Aggregator parent,
@@ -55,10 +58,33 @@ public class RandomSamplerAggregator extends BucketsAggregator implements Single
5558
RandomSamplerAggregationBuilder.NAME + " aggregation [" + name + "] must have sub aggregations configured"
5659
);
5760
}
58-
this.weightSupplier = weightSupplier;
5961
this.shardSeed = shardSeed;
6062
}
6163

64+
/**
65+
* This creates the query weight which will be used in the aggregator.
66+
*
67+
* This weight is a boolean query between {@link RandomSamplingQuery} and the configured top level query of the search. This allows
68+
* the aggregation to iterate the documents directly, thus sampling in the background instead of the foreground.
69+
* @return weight to be used, is cached for additional usages
70+
* @throws IOException when building the weight or queries fails;
71+
*/
72+
private Weight getWeight() throws IOException {
73+
if (weight == null) {
74+
ScoreMode scoreMode = scoreMode();
75+
BooleanQuery.Builder fullQuery = new BooleanQuery.Builder().add(
76+
context.query(),
77+
scoreMode.needsScores() ? BooleanClause.Occur.MUST : BooleanClause.Occur.FILTER
78+
);
79+
if (probability < 1.0) {
80+
Query sampleQuery = new RandomSamplingQuery(probability, seed, shardSeed == null ? context.shardRandomSeed() : shardSeed);
81+
fullQuery.add(sampleQuery, BooleanClause.Occur.FILTER);
82+
}
83+
weight = context.searcher().createWeight(context.searcher().rewrite(fullQuery.build()), scoreMode, 1f);
84+
}
85+
return weight;
86+
}
87+
6288
@Override
6389
public InternalAggregation[] buildAggregations(long[] owningBucketOrds) throws IOException {
6490
return buildAggregationsForSingleBucket(
@@ -100,22 +126,26 @@ protected LeafBucketCollector getLeafCollector(AggregationExecutionContext aggCt
100126
if (sub.isNoop()) {
101127
return LeafBucketCollector.NO_OP_COLLECTOR;
102128
}
129+
130+
Scorer scorer = getWeight().scorer(aggCtx.getLeafReaderContext());
131+
// This means there are no docs to iterate, possibly due to the fields not existing
132+
if (scorer == null) {
133+
return LeafBucketCollector.NO_OP_COLLECTOR;
134+
}
135+
sub.setScorer(scorer);
136+
103137
// No sampling is being done, collect all docs
138+
// TODO know when sampling would be much slower and skip sampling: https://github.com/elastic/elasticsearch/issues/84353
104139
if (probability >= 1.0) {
105140
grow(1);
106-
return new LeafBucketCollector() {
141+
return new LeafBucketCollectorBase(sub, null) {
107142
@Override
108143
public void collect(int doc, long owningBucketOrd) throws IOException {
109144
collectExistingBucket(sub, doc, 0);
110145
}
111146
};
112147
}
113-
// TODO know when sampling would be much slower and skip sampling: https://github.com/elastic/elasticsearch/issues/84353
114-
Scorer scorer = weightSupplier.get().scorer(aggCtx.getLeafReaderContext());
115-
// This means there are no docs to iterate, possibly due to the fields not existing
116-
if (scorer == null) {
117-
return LeafBucketCollector.NO_OP_COLLECTOR;
118-
}
148+
119149
final DocIdSetIterator docIt = scorer.iterator();
120150
final Bits liveDocs = aggCtx.getLeafReaderContext().reader().getLiveDocs();
121151
try {
@@ -135,5 +165,4 @@ public void collect(int doc, long owningBucketOrd) throws IOException {
135165
// Since we have done our own collection, there is nothing for the leaf collector to do
136166
return LeafBucketCollector.NO_OP_COLLECTOR;
137167
}
138-
139168
}

server/src/main/java/org/elasticsearch/search/aggregations/bucket/sampler/random/RandomSamplerAggregatorFactory.java

+1-41
Original file line numberDiff line numberDiff line change
@@ -9,10 +9,6 @@
99

1010
package org.elasticsearch.search.aggregations.bucket.sampler.random;
1111

12-
import org.apache.lucene.search.BooleanClause;
13-
import org.apache.lucene.search.BooleanQuery;
14-
import org.apache.lucene.search.ScoreMode;
15-
import org.apache.lucene.search.Weight;
1612
import org.elasticsearch.search.aggregations.Aggregator;
1713
import org.elasticsearch.search.aggregations.AggregatorFactories;
1814
import org.elasticsearch.search.aggregations.AggregatorFactory;
@@ -30,7 +26,6 @@ public class RandomSamplerAggregatorFactory extends AggregatorFactory {
3026
private final Integer shardSeed;
3127
private final double probability;
3228
private final SamplingContext samplingContext;
33-
private Weight weight;
3429

3530
RandomSamplerAggregatorFactory(
3631
String name,
@@ -57,41 +52,6 @@ public Optional<SamplingContext> getSamplingContext() {
5752
@Override
5853
public Aggregator createInternal(Aggregator parent, CardinalityUpperBound cardinality, Map<String, Object> metadata)
5954
throws IOException {
60-
return new RandomSamplerAggregator(
61-
name,
62-
seed,
63-
shardSeed,
64-
probability,
65-
this::getWeight,
66-
factories,
67-
context,
68-
parent,
69-
cardinality,
70-
metadata
71-
);
55+
return new RandomSamplerAggregator(name, seed, shardSeed, probability, factories, context, parent, cardinality, metadata);
7256
}
73-
74-
/**
75-
* This creates the query weight which will be used in the aggregator.
76-
*
77-
* This weight is a boolean query between {@link RandomSamplingQuery} and the configured top level query of the search. This allows
78-
* the aggregation to iterate the documents directly, thus sampling in the background instead of the foreground.
79-
* @return weight to be used, is cached for additional usages
80-
* @throws IOException when building the weight or queries fails;
81-
*/
82-
private Weight getWeight() throws IOException {
83-
if (weight == null) {
84-
RandomSamplingQuery query = new RandomSamplingQuery(
85-
probability,
86-
seed,
87-
shardSeed == null ? context.shardRandomSeed() : shardSeed
88-
);
89-
BooleanQuery booleanQuery = new BooleanQuery.Builder().add(query, BooleanClause.Occur.FILTER)
90-
.add(context.query(), BooleanClause.Occur.FILTER)
91-
.build();
92-
weight = context.searcher().createWeight(context.searcher().rewrite(booleanQuery), ScoreMode.COMPLETE_NO_SCORES, 1f);
93-
}
94-
return weight;
95-
}
96-
9757
}

server/src/test/java/org/elasticsearch/search/aggregations/bucket/sampler/random/RandomSamplerAggregatorTests.java

+38
Original file line numberDiff line numberDiff line change
@@ -11,22 +11,29 @@
1111

1212
import org.apache.lucene.document.LongPoint;
1313
import org.apache.lucene.document.SortedNumericDocValuesField;
14+
import org.apache.lucene.index.Term;
15+
import org.apache.lucene.search.BooleanClause;
16+
import org.apache.lucene.search.BooleanQuery;
17+
import org.apache.lucene.search.TermQuery;
1418
import org.apache.lucene.tests.index.RandomIndexWriter;
1519
import org.apache.lucene.util.BytesRef;
1620
import org.elasticsearch.common.Strings;
1721
import org.elasticsearch.index.mapper.KeywordFieldMapper;
1822
import org.elasticsearch.index.query.QueryBuilders;
23+
import org.elasticsearch.search.SearchHit;
1924
import org.elasticsearch.search.aggregations.AggregationBuilders;
2025
import org.elasticsearch.search.aggregations.AggregatorTestCase;
2126
import org.elasticsearch.search.aggregations.bucket.filter.Filter;
2227
import org.elasticsearch.search.aggregations.metrics.Avg;
2328
import org.elasticsearch.search.aggregations.metrics.Max;
2429
import org.elasticsearch.search.aggregations.metrics.Min;
30+
import org.elasticsearch.search.aggregations.metrics.TopHits;
2531
import org.hamcrest.Description;
2632
import org.hamcrest.Matcher;
2733
import org.hamcrest.TypeSafeMatcher;
2834

2935
import java.io.IOException;
36+
import java.util.Arrays;
3037
import java.util.List;
3138
import java.util.concurrent.atomic.AtomicInteger;
3239
import java.util.stream.DoubleStream;
@@ -37,6 +44,8 @@
3744
import static org.hamcrest.Matchers.equalTo;
3845
import static org.hamcrest.Matchers.greaterThan;
3946
import static org.hamcrest.Matchers.greaterThanOrEqualTo;
47+
import static org.hamcrest.Matchers.hasSize;
48+
import static org.hamcrest.Matchers.lessThan;
4049
import static org.hamcrest.Matchers.lessThanOrEqualTo;
4150
import static org.hamcrest.Matchers.not;
4251
import static org.hamcrest.Matchers.notANumber;
@@ -76,6 +85,35 @@ public void testAggregationSampling() throws IOException {
7685
assertThat(avgAvg, closeTo(1.5, 0.5));
7786
}
7887

88+
public void testAggregationSampling_withScores() throws IOException {
89+
long[] counts = new long[5];
90+
AtomicInteger integer = new AtomicInteger();
91+
do {
92+
testCase(RandomSamplerAggregatorTests::writeTestDocs, (InternalRandomSampler result) -> {
93+
counts[integer.get()] = result.getDocCount();
94+
if (result.getDocCount() > 0) {
95+
TopHits agg = result.getAggregations().get("top");
96+
List<SearchHit> hits = Arrays.asList(agg.getHits().getHits());
97+
assertThat(Strings.toString(result), hits, hasSize(1));
98+
assertThat(Strings.toString(result), hits.get(0).getScore(), allOf(greaterThan(0.0f), lessThan(1.0f)));
99+
}
100+
},
101+
new AggTestConfig(
102+
new RandomSamplerAggregationBuilder("my_agg").subAggregation(AggregationBuilders.topHits("top").size(1))
103+
.setProbability(0.25),
104+
longField(NUMERIC_FIELD_NAME)
105+
).withQuery(
106+
new BooleanQuery.Builder().add(
107+
new TermQuery(new Term(KEYWORD_FIELD_NAME, KEYWORD_FIELD_VALUE)),
108+
BooleanClause.Occur.SHOULD
109+
).build()
110+
)
111+
);
112+
} while (integer.incrementAndGet() < 5);
113+
long avgCount = LongStream.of(counts).sum() / integer.get();
114+
assertThat(avgCount, allOf(greaterThanOrEqualTo(20L), lessThanOrEqualTo(70L)));
115+
}
116+
79117
public void testAggregationSamplingNestedAggsScaled() throws IOException {
80118
// in case 0 docs get sampled, which can rarely happen
81119
// in case the test index has many segments.

0 commit comments

Comments
 (0)