diff --git a/server/src/main/java/org/elasticsearch/search/aggregations/bucket/sampler/random/RandomSamplingQuery.java b/server/src/main/java/org/elasticsearch/search/aggregations/bucket/sampler/random/RandomSamplingQuery.java index 8c1ecc2a715ff..45a1c53690817 100644 --- a/server/src/main/java/org/elasticsearch/search/aggregations/bucket/sampler/random/RandomSamplingQuery.java +++ b/server/src/main/java/org/elasticsearch/search/aggregations/bucket/sampler/random/RandomSamplingQuery.java @@ -32,7 +32,6 @@ public final class RandomSamplingQuery extends Query { private final double p; - private final SplittableRandom splittableRandom; private final int seed; private final int hash; @@ -49,7 +48,6 @@ public RandomSamplingQuery(double p, int seed, int hash) { this.p = p; this.seed = seed; this.hash = hash; - this.splittableRandom = new SplittableRandom(BitMixer.mix(hash, seed)); } @Override @@ -78,7 +76,7 @@ public Explanation explain(LeafReaderContext context, int doc) throws IOExceptio @Override public Scorer scorer(LeafReaderContext context) { - final SplittableRandom random = splittableRandom.split(); + final SplittableRandom random = new SplittableRandom(BitMixer.mix(hash, seed)); int maxDoc = context.reader().maxDoc(); return new ConstantScoreScorer( this, diff --git a/server/src/test/java/org/elasticsearch/search/aggregations/bucket/sampler/random/RandomDocIDSetIteratorTests.java b/server/src/test/java/org/elasticsearch/search/aggregations/bucket/sampler/random/RandomDocIDSetIteratorTests.java index a5b9a75281144..e091a4602c441 100644 --- a/server/src/test/java/org/elasticsearch/search/aggregations/bucket/sampler/random/RandomDocIDSetIteratorTests.java +++ b/server/src/test/java/org/elasticsearch/search/aggregations/bucket/sampler/random/RandomDocIDSetIteratorTests.java @@ -11,8 +11,12 @@ import org.apache.lucene.search.DocIdSetIterator; import org.elasticsearch.test.ESTestCase; +import java.util.ArrayList; +import java.util.List; import java.util.SplittableRandom; +import static org.hamcrest.Matchers.equalTo; + public class RandomDocIDSetIteratorTests extends ESTestCase { public void testRandomSampler() { @@ -43,4 +47,26 @@ public void testRandomSampler() { } } + public void testRandomSamplerConsistency() { + int maxDoc = 10000; + int seed = randomInt(); + + for (int i = 1; i < 100; i++) { + double p = i / 100.0; + SplittableRandom random = new SplittableRandom(seed); + List iterationOne = new ArrayList<>(); + RandomSamplingQuery.RandomSamplingIterator iter = new RandomSamplingQuery.RandomSamplingIterator(maxDoc, p, random::nextInt); + while (iter.nextDoc() != DocIdSetIterator.NO_MORE_DOCS) { + iterationOne.add(iter.docID()); + } + random = new SplittableRandom(seed); + List iterationTwo = new ArrayList<>(); + iter = new RandomSamplingQuery.RandomSamplingIterator(maxDoc, p, random::nextInt); + while (iter.nextDoc() != DocIdSetIterator.NO_MORE_DOCS) { + iterationTwo.add(iter.docID()); + } + assertThat(iterationOne, equalTo(iterationTwo)); + } + } + }