Skip to content

Add BKD Optimization to Range aggregation #47712

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 4 commits into from
Closed
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -193,6 +193,19 @@ public Number parsePoint(byte[] value) {
return HalfFloatPoint.decodeDimension(value, 0);
}

@Override
public byte[] encodePoint(Number value, boolean coerce) {
float parsedValue = parse(value, coerce);
byte[] bytes = new byte[Integer.BYTES];
HalfFloatPoint.encodeDimension(parsedValue, bytes, 0);
return bytes;
}

@Override
public int bytesPerEncodedPoint() {
return Integer.BYTES;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I thought we used 2 bytes for HalfFloat, not 4?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You can use HalfFloatPoint.BYTES.

}

@Override
public Float parse(XContentParser parser, boolean coerce) throws IOException {
float parsed = parser.floatValue(coerce);
Expand Down Expand Up @@ -290,6 +303,19 @@ public Number parsePoint(byte[] value) {
return FloatPoint.decodeDimension(value, 0);
}

@Override
public byte[] encodePoint(Number value, boolean coerce) {
float parsedValue = parse(value, coerce);
byte[] bytes = new byte[Integer.BYTES];
FloatPoint.encodeDimension(parsedValue, bytes, 0);
return bytes;
}

@Override
public int bytesPerEncodedPoint() {
return Integer.BYTES;
}

@Override
public Float parse(XContentParser parser, boolean coerce) throws IOException {
float parsed = parser.floatValue(coerce);
Expand Down Expand Up @@ -376,6 +402,19 @@ public Number parsePoint(byte[] value) {
return DoublePoint.decodeDimension(value, 0);
}

@Override
public byte[] encodePoint(Number value, boolean coerce) {
double parsedValue = parse(value, coerce);
byte[] bytes = new byte[Long.BYTES];
DoublePoint.encodeDimension(parsedValue, bytes, 0);
return bytes;
}

@Override
public int bytesPerEncodedPoint() {
return Long.BYTES;
}

@Override
public Double parse(XContentParser parser, boolean coerce) throws IOException {
double parsed = parser.doubleValue(coerce);
Expand Down Expand Up @@ -473,6 +512,21 @@ public Number parsePoint(byte[] value) {
return INTEGER.parsePoint(value).byteValue();
}

@Override
public byte[] encodePoint(Number value, boolean coerce) {
int parsedValue = parse(value, coerce);

// Same as integer
byte[] bytes = new byte[Integer.BYTES];
IntPoint.encodeDimension(parsedValue, bytes, 0);
return bytes;
}

@Override
public int bytesPerEncodedPoint() {
return Integer.BYTES;
}

@Override
public Short parse(XContentParser parser, boolean coerce) throws IOException {
int value = parser.intValue(coerce);
Expand Down Expand Up @@ -534,6 +588,21 @@ public Number parsePoint(byte[] value) {
return INTEGER.parsePoint(value).shortValue();
}

@Override
public byte[] encodePoint(Number value, boolean coerce) {
int parsedValue = parse(value, coerce);

// Same as integer
byte[] bytes = new byte[Integer.BYTES];
IntPoint.encodeDimension(parsedValue, bytes, 0);
return bytes;
}

@Override
public int bytesPerEncodedPoint() {
return Integer.BYTES;
}

@Override
public Short parse(XContentParser parser, boolean coerce) throws IOException {
return parser.shortValue(coerce);
Expand Down Expand Up @@ -591,6 +660,19 @@ public Number parsePoint(byte[] value) {
return IntPoint.decodeDimension(value, 0);
}

@Override
public byte[] encodePoint(Number value, boolean coerce) {
int parsedValue = parse(value, coerce);
byte[] bytes = new byte[Integer.BYTES];
IntPoint.encodeDimension(parsedValue, bytes, 0);
return bytes;
}

@Override
public int bytesPerEncodedPoint() {
return Integer.BYTES;
}

@Override
public Integer parse(XContentParser parser, boolean coerce) throws IOException {
return parser.intValue(coerce);
Expand Down Expand Up @@ -710,6 +792,19 @@ public Number parsePoint(byte[] value) {
return LongPoint.decodeDimension(value, 0);
}

@Override
public byte[] encodePoint(Number value, boolean coerce) {
long parsedValue = parse(value, coerce);
byte[] bytes = new byte[Long.BYTES];
LongPoint.encodeDimension(parsedValue, bytes, 0);
return bytes;
}

@Override
public int bytesPerEncodedPoint() {
return Long.BYTES;
}

@Override
public Long parse(XContentParser parser, boolean coerce) throws IOException {
return parser.longValue(coerce);
Expand Down Expand Up @@ -827,6 +922,8 @@ public abstract Query rangeQuery(String field, Object lowerTerm, Object upperTer
public abstract Number parse(XContentParser parser, boolean coerce) throws IOException;
public abstract Number parse(Object value, boolean coerce);
public abstract Number parsePoint(byte[] value);
public abstract byte[] encodePoint(Number value, boolean coerce);
public abstract int bytesPerEncodedPoint();
public abstract List<Field> createFields(String name, Number value, boolean indexed,
boolean docValued, boolean stored);
Number valueForSearch(Number value) {
Expand Down Expand Up @@ -979,6 +1076,14 @@ public Number parsePoint(byte[] value) {
return type.parsePoint(value);
}

public byte[] encodePoint(Number value, boolean coerce) {
return type.encodePoint(value, coerce);
}

public int bytesPerEncodedPoint() {
return type.bytesPerEncodedPoint();
}

@Override
public boolean equals(Object o) {
if (super.equals(o) == false) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -164,18 +164,26 @@ public AggParseContext(String name) {

public static final AggregatorFactories EMPTY = new AggregatorFactories(new AggregatorFactory[0], new ArrayList<>());

private AggregatorFactory[] factories;
protected AggregatorFactory[] factories;
private List<PipelineAggregationBuilder> pipelineAggregatorFactories;

public static Builder builder() {
return new Builder();
}

private AggregatorFactories(AggregatorFactory[] factories, List<PipelineAggregationBuilder> pipelineAggregators) {
protected AggregatorFactories(AggregatorFactory[] factories, List<PipelineAggregationBuilder> pipelineAggregators) {
this.factories = factories;
this.pipelineAggregatorFactories = pipelineAggregators;
}

public AggregatorFactory[] getFactories() {
return factories;
}

public List<PipelineAggregationBuilder> getPipelineAggregatorFactories() {
return pipelineAggregatorFactories;
}

public List<PipelineAggregator> createPipelineAggregators() {
List<PipelineAggregator> pipelineAggregators = new ArrayList<>(this.pipelineAggregatorFactories.size());
for (PipelineAggregationBuilder factory : this.pipelineAggregatorFactories) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,15 +37,15 @@
public abstract class AggregatorFactory {

public static final class MultiBucketAggregatorWrapper extends Aggregator {
private final BigArrays bigArrays;
private final Aggregator parent;
private final AggregatorFactory factory;
protected final BigArrays bigArrays;
protected final AggregatorFactory factory;
protected ObjectArray<Aggregator> aggregators;
protected ObjectArray<LeafBucketCollector> collectors;
protected final Aggregator parent;
private final Aggregator first;
ObjectArray<Aggregator> aggregators;
ObjectArray<LeafBucketCollector> collectors;

MultiBucketAggregatorWrapper(BigArrays bigArrays, SearchContext context,
Aggregator parent, AggregatorFactory factory, Aggregator first) {
Aggregator parent, AggregatorFactory factory, Aggregator first) {
this.bigArrays = bigArrays;
this.parent = parent;
this.factory = factory;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,8 @@
import org.elasticsearch.search.aggregations.support.ValuesSourceAggregatorFactory;
import org.elasticsearch.search.aggregations.support.ValuesSourceConfig;
import org.elasticsearch.search.internal.SearchContext;
import org.elasticsearch.search.profile.Profilers;
import org.elasticsearch.search.profile.aggregation.ProfilingAggregator;

import java.io.IOException;
import java.util.List;
Expand All @@ -42,15 +44,15 @@ public class AbstractRangeAggregatorFactory<R extends Range> extends ValuesSourc
private final R[] ranges;
private final boolean keyed;

public AbstractRangeAggregatorFactory(String name,
ValuesSourceConfig<Numeric> config,
R[] ranges,
boolean keyed,
InternalRange.Factory<?, ?> rangeFactory,
QueryShardContext queryShardContext,
AggregatorFactory parent,
AggregatorFactories.Builder subFactoriesBuilder,
Map<String, Object> metaData) throws IOException {
AbstractRangeAggregatorFactory(String name,
ValuesSourceConfig<Numeric> config,
R[] ranges,
boolean keyed,
InternalRange.Factory<?, ?> rangeFactory,
QueryShardContext queryShardContext,
AggregatorFactory parent,
AggregatorFactories.Builder subFactoriesBuilder,
Map<String, Object> metaData) throws IOException {
super(name, config, queryShardContext, parent, subFactoriesBuilder, metaData);
this.ranges = ranges;
this.keyed = keyed;
Expand All @@ -59,21 +61,47 @@ public AbstractRangeAggregatorFactory(String name,

@Override
protected Aggregator createUnmapped(SearchContext searchContext,
Aggregator parent,
List<PipelineAggregator> pipelineAggregators,
Map<String, Object> metaData) throws IOException {
Aggregator parent,
List<PipelineAggregator> pipelineAggregators,
Map<String, Object> metaData) throws IOException {
return new Unmapped<>(name, ranges, keyed, config.format(), searchContext, parent, rangeFactory, pipelineAggregators, metaData);
}

@Override
protected Aggregator doCreateInternal(Numeric valuesSource,
SearchContext searchContext,
Aggregator parent,
boolean collectsFromSingleBucket,
List<PipelineAggregator> pipelineAggregators,
Map<String, Object> metaData) throws IOException {
return new RangeAggregator(name, factories, valuesSource, config.format(), rangeFactory, ranges, keyed, searchContext, parent,
pipelineAggregators, metaData);
SearchContext searchContext,
Aggregator parent,
boolean collectsFromSingleBucket,
List<PipelineAggregator> pipelineAggregators,
Map<String, Object> metaData) throws IOException {

AggregatorFactories wrappedFactories = factories;

// If we don't have a parent, the range agg can potentially optimize by using the BKD tree. But BKD
// traversal is per-range, which means that docs are potentially called out-of-order across multiple
// ranges. To prevent this from causing problems, we create a special AggregatorFactories that
// wraps all the sub-aggs with a MultiBucketAggregatorWrapper. This effectively creates a new agg
// sub-tree for each range and prevents out-of-order problems
if (parent == null) {
wrappedFactories = new AggregatorFactories(factories.getFactories(), factories.getPipelineAggregatorFactories()) {
@Override
public Aggregator[] createSubAggregators(SearchContext searchContext, Aggregator parent) throws IOException {
Aggregator[] aggregators = new Aggregator[countAggregators()];
for (int i = 0; i < this.factories.length; ++i) {
Aggregator factory = asMultiBucketAggregator(factories[i], searchContext, parent);
Profilers profilers = factory.context().getProfilers();
if (profilers != null) {
factory = new ProfilingAggregator(factory, profilers.getAggregationProfiler());
}
aggregators[i] = factory;
}
return aggregators;
}
};
}

return new RangeAggregator(name, wrappedFactories, valuesSource, config, rangeFactory, ranges, keyed, searchContext, parent,
pipelineAggregators, metaData);
}


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -81,9 +81,8 @@ protected Aggregator doCreateInternal(final ValuesSource.GeoPoint valuesSource,
List<PipelineAggregator> pipelineAggregators,
Map<String, Object> metaData) throws IOException {
DistanceSource distanceSource = new DistanceSource(valuesSource, distanceType, origin, unit);
return new RangeAggregator(name, factories, distanceSource, config.format(), rangeFactory, ranges, keyed, searchContext,
parent,
pipelineAggregators, metaData);
return new RangeAggregator(name, factories, distanceSource, config, rangeFactory, ranges, keyed, searchContext,
parent, pipelineAggregators, metaData);
}

private static class DistanceSource extends ValuesSource.Numeric {
Expand Down
Loading