Skip to content

Ordering term aggregation based on scripted metric. #15718

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -23,13 +23,15 @@
import org.elasticsearch.common.lease.Releasables;
import org.elasticsearch.common.util.BigArrays;
import org.elasticsearch.common.util.ObjectArray;
import org.elasticsearch.script.*;
import org.elasticsearch.search.aggregations.metrics.NumericMetricsAggregator;
import org.elasticsearch.search.aggregations.pipeline.PipelineAggregator;
import org.elasticsearch.search.aggregations.support.AggregationContext;
import org.elasticsearch.search.internal.SearchContext.Lifetime;
import org.elasticsearch.transport.TransportRequest;

import java.io.IOException;
import java.util.List;
import java.util.Map;
import java.util.*;

/**
* A factory that knows how to create an {@link Aggregator} of a specific type.
Expand Down Expand Up @@ -234,4 +236,143 @@ public void close() {
};
}

/**
 * Specific multi-bucket wrapper used by the scripted metric aggregation.
 *
 * <p>Creates one aggregator per owning bucket from {@code factory} (lazily, the first time a
 * document is collected into that bucket) and delegates collection to it using bucket ordinal
 * {@code 0}. Unlike a plain {@link Aggregator} wrapper, the returned object is a
 * {@link NumericMetricsAggregator.SingleValue}, so that the result of the aggregation can be
 * used to order buckets. It is single-valued because the only property read from the wrapped
 * aggregation is {@code "value"}.
 *
 * @param factory      factory used to create the per-bucket aggregators
 * @param context      aggregation context the aggregators are created in
 * @param parent       parent aggregator handed to every created aggregator
 * @param reduceScript script applied to the per-bucket {@code "value"} property to produce the
 *                     numeric metric; it is exposed to the script as a one-element {@code _aggs}
 *                     list, together with the script's own params
 * @return an aggregator able to collect into several buckets and to report a numeric metric
 *         per bucket
 * @throws IOException if creating or preparing the first aggregator fails
 */
protected static Aggregator asMultiBucketAggregator(final AggregatorFactory factory,
                                                    final AggregationContext context,
                                                    final Aggregator parent,
                                                    Script reduceScript) throws IOException {
    final Aggregator first = factory.create(context, parent, true);
    final BigArrays bigArrays = context.bigArrays();

    first.preCollection();

    // Returns a NumericMetricsAggregator instead of a simple Aggregator, so the result of the
    // aggregation can be used to order the buckets.
    // The name given to the super constructor is a placeholder: name() is overridden below and
    // delegates to the first wrapped aggregator.
    return new NumericMetricsAggregator.SingleValue("", context, parent, new ArrayList<>(), null) {

        /**
         * Builds the aggregation of the given bucket, runs the reduce script on its
         * "value" property and returns the script result as a double.
         *
         * @throws AggregationExecutionException if there is no reduce script, if the script
         *         does not return a number, or if building the aggregation fails
         */
        @Override
        public double metric(final long owningBucketOrd) {
            if (reduceScript == null) {
                // Without a reduce script there is nothing that turns the raw bucket state into
                // a number; report it explicitly instead of failing with a NullPointerException.
                throw new AggregationExecutionException("Cannot compute a numeric value for aggregation ["
                        + name() + "]. A reduce script is required.");
            }
            try {
                Object aggregationObject = buildAggregation(owningBucketOrd).getProperty("value");
                List<Object> aggregationObjects = Arrays.asList(aggregationObject);
                Map<String, Object> vars = new HashMap<>();
                vars.put("_aggs", aggregationObjects);

                if (reduceScript.getParams() != null) {
                    vars.putAll(reduceScript.getParams());
                }
                ScriptService scriptService = context().searchContext().scriptService();
                CompiledScript compiledScript = scriptService.compile(reduceScript,
                        ScriptContext.Standard.AGGS,
                        new InternalAggregation.ReduceContext(bigArrays, scriptService, new TransportRequest.Empty()));
                ExecutableScript script = scriptService.executable(compiledScript, vars);
                Object value = script.run();

                if (value instanceof Number) {
                    return ((Number) value).doubleValue();
                }
                // Identify the aggregation by name() and show the offending value: "this" is the
                // anonymous wrapper, not an order path.
                throw new AggregationExecutionException("Invalid result [" + value
                        + "] of the reduce script of aggregation [" + name()
                        + "]. Only numeric results are supported.");
            } catch (IOException e) {
                throw new AggregationExecutionException("Failed to build aggregation [" + name() + "]", e);
            }
        }

        @Override
        protected LeafBucketCollector getLeafCollector(final LeafReaderContext ctx,
                                                       final LeafBucketCollector sub) throws IOException {
            // Leaf collectors are bound to a segment: drop the ones of the previous segment so
            // they are re-created lazily for ctx.
            for (long i = 0; i < collectors.size(); ++i) {
                collectors.set(i, null);
            }
            return new LeafBucketCollector() {
                Scorer scorer;

                @Override
                public void setScorer(Scorer scorer) throws IOException {
                    this.scorer = scorer;
                }

                @Override
                public void collect(int doc, long bucket) throws IOException {
                    aggregators = bigArrays.grow(aggregators, bucket + 1);
                    collectors = bigArrays.grow(collectors, bucket + 1);

                    LeafBucketCollector collector = collectors.get(bucket);
                    if (collector == null) {
                        Aggregator aggregator = aggregators.get(bucket);
                        if (aggregator == null) {
                            // First document seen for this bucket: create its dedicated aggregator.
                            aggregator = factory.create(context, parent, true);
                            aggregator.preCollection();
                            aggregators.set(bucket, aggregator);
                        }
                        collector = aggregator.getLeafCollector(ctx);
                        collector.setScorer(scorer);
                        collectors.set(bucket, collector);
                    }
                    // Every wrapped aggregator only knows a single bucket, ordinal 0.
                    collector.collect(doc, 0);
                }

            };
        }

        // One wrapped aggregator / leaf collector per owning bucket ordinal.
        ObjectArray<Aggregator> aggregators;
        ObjectArray<LeafBucketCollector> collectors;

        {
            context.searchContext().addReleasable(this, Lifetime.PHASE);
            aggregators = bigArrays.newObjectArray(1);
            aggregators.set(0, first);
            collectors = bigArrays.newObjectArray(1);
        }

        @Override
        public String name() {
            return first.name();
        }

        @Override
        public AggregationContext context() {
            return first.context();
        }

        @Override
        public Aggregator parent() {
            return first.parent();
        }

        @Override
        public boolean needsScores() {
            return first.needsScores();
        }

        @Override
        public Aggregator subAggregator(String name) {
            throw new UnsupportedOperationException();
        }

        @Override
        public InternalAggregation buildAggregation(long bucket) throws IOException {
            if (bucket < aggregators.size()) {
                Aggregator aggregator = aggregators.get(bucket);
                if (aggregator != null) {
                    return aggregator.buildAggregation(0);
                }
            }
            // No document was collected for this bucket.
            return buildEmptyAggregation();
        }

        @Override
        public InternalAggregation buildEmptyAggregation() {
            return first.buildEmptyAggregation();
        }

        @Override
        public void close() {
            Releasables.close(aggregators, collectors);
        }
    };
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -24,17 +24,15 @@
import org.elasticsearch.common.util.DoubleArray;
import org.elasticsearch.common.util.LongArray;
import org.elasticsearch.index.fielddata.SortedNumericDoubleValues;
import org.elasticsearch.search.aggregations.Aggregator;
import org.elasticsearch.search.aggregations.InternalAggregation;
import org.elasticsearch.search.aggregations.LeafBucketCollector;
import org.elasticsearch.search.aggregations.LeafBucketCollectorBase;
import org.elasticsearch.search.aggregations.*;
import org.elasticsearch.search.aggregations.metrics.NumericMetricsAggregator;
import org.elasticsearch.search.aggregations.pipeline.PipelineAggregator;
import org.elasticsearch.search.aggregations.support.AggregationContext;
import org.elasticsearch.search.aggregations.support.ValuesSource;
import org.elasticsearch.search.aggregations.support.ValuesSourceAggregatorFactory;
import org.elasticsearch.search.aggregations.support.ValuesSourceConfig;
import org.elasticsearch.search.aggregations.support.format.ValueFormatter;
import org.elasticsearch.search.aggregations.support.values.ScriptDoubleValues;

import java.io.IOException;
import java.util.List;
Expand All @@ -52,8 +50,8 @@ public class AvgAggregator extends NumericMetricsAggregator.SingleValue {
ValueFormatter formatter;

public AvgAggregator(String name, ValuesSource.Numeric valuesSource, ValueFormatter formatter,
AggregationContext context,
Aggregator parent, List<PipelineAggregator> pipelineAggregators, Map<String, Object> metaData) throws IOException {
AggregationContext context,
Aggregator parent, List<PipelineAggregator> pipelineAggregators, Map<String, Object> metaData) throws IOException {
super(name, context, parent, pipelineAggregators, metaData);
this.valuesSource = valuesSource;
this.formatter = formatter;
Expand All @@ -71,7 +69,7 @@ public boolean needsScores() {

@Override
public LeafBucketCollector getLeafCollector(LeafReaderContext ctx,
final LeafBucketCollector sub) throws IOException {
final LeafBucketCollector sub) throws IOException {
if (valuesSource == null) {
return LeafBucketCollector.NO_OP_COLLECTOR;
}
Expand All @@ -85,11 +83,47 @@ public void collect(int doc, long bucket) throws IOException {

values.setDocument(doc);
final int valueCount = values.count();
counts.increment(bucket, valueCount);
double sum = 0;
for (int i = 0; i < valueCount; i++) {
sum += values.valueAt(i);
long count = 0;
if (values instanceof ScriptDoubleValues && ((ScriptDoubleValues)values).hasExtendedScriptResult()) {
// Avg aggregation with a script returning a weight or a list of weights
ScriptDoubleValues scriptDoubleValues = (ScriptDoubleValues)values;
Object weightObj = scriptDoubleValues.getExtendedScriptResult().get("weight");
Object weightsObj = scriptDoubleValues.getExtendedScriptResult().get("weights");
if(weightsObj instanceof List) {
List weights = (List) weightsObj;
for (int i = 0; i < valueCount; i++) {
// By default, a missing weight is considered equal to 1.
long weight = 1;
if(i < weights.size()) {
Object weightObjAtIndexI = weights.get(i);
if(weightObjAtIndexI instanceof Number) {
weight = ((Number)weightObjAtIndexI).longValue();
} else {
throw new AggregationExecutionException("Unsupported weight value [" + weightObjAtIndexI + "], should be a number");
}
}
count += weight;
sum += values.valueAt(i) * weight;
}
} else {
// By default, the weight is equal to 1.
long weight = 1;
if(weightObj instanceof Number) {
weight = ((Number)weightObj).longValue();
}
for (int i = 0; i < valueCount; i++) {
count += weight;
sum += values.valueAt(i) * weight;
}
}
} else {
count += valueCount;
for (int i = 0; i < valueCount; i++) {
sum += values.valueAt(i);
}
}
counts.increment(bucket, count);
sums.increment(bucket, sum);
}
};
Expand Down Expand Up @@ -121,14 +155,14 @@ public Factory(String name, String type, ValuesSourceConfig<ValuesSource.Numeric

@Override
protected Aggregator createUnmapped(AggregationContext aggregationContext, Aggregator parent,
List<PipelineAggregator> pipelineAggregators,
Map<String, Object> metaData) throws IOException {
List<PipelineAggregator> pipelineAggregators,
Map<String, Object> metaData) throws IOException {
return new AvgAggregator(name, null, config.formatter(), aggregationContext, parent, pipelineAggregators, metaData);
}

@Override
protected Aggregator doCreateInternal(ValuesSource.Numeric valuesSource, AggregationContext aggregationContext, Aggregator parent,
boolean collectsFromSingleBucket, List<PipelineAggregator> pipelineAggregators, Map<String, Object> metaData)
boolean collectsFromSingleBucket, List<PipelineAggregator> pipelineAggregators, Map<String, Object> metaData)
throws IOException {
return new AvgAggregator(name, valuesSource, config.formatter(), aggregationContext, parent, pipelineAggregators, metaData);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,7 @@ public Factory(String name, Script initScript, Script mapScript, Script combineS
public Aggregator createInternal(AggregationContext context, Aggregator parent, boolean collectsFromSingleBucket,
List<PipelineAggregator> pipelineAggregators, Map<String, Object> metaData) throws IOException {
if (collectsFromSingleBucket == false) {
return asMultiBucketAggregator(this, context, parent);
return asMultiBucketAggregator(this, context, parent, reduceScript);
}
Map<String, Object> params = this.params;
if (params != null) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,7 @@
package org.elasticsearch.search.aggregations.support;

import org.elasticsearch.common.Strings;
import org.elasticsearch.search.aggregations.Aggregation;
import org.elasticsearch.search.aggregations.AggregationExecutionException;
import org.elasticsearch.search.aggregations.Aggregator;
import org.elasticsearch.search.aggregations.HasAggregations;
import org.elasticsearch.search.aggregations.*;
import org.elasticsearch.search.aggregations.bucket.SingleBucketAggregation;
import org.elasticsearch.search.aggregations.bucket.SingleBucketAggregator;
import org.elasticsearch.search.aggregations.metrics.InternalNumericMetricsAggregation;
Expand Down Expand Up @@ -240,7 +237,24 @@ public double resolveValue(HasAggregations root) {
"]. Missing value key in [" + token + "] which refers to a multi-value metric aggregation");
}
parent = null;
value = ((InternalNumericMetricsAggregation.MultiValue) agg).value(token.key);
if(agg instanceof InternalNumericMetricsAggregation.MultiValue) {
// For NumericMetricsAggregation we use the method value returning a native double
// Optimization to avoid object creation and multiple casts.
value = ((InternalNumericMetricsAggregation.MultiValue) agg).value(token.key);
} else if(agg instanceof InternalAggregation) {
// For a general use case, we use the method getProperty returning an Object
Object propertyValue = agg.getProperty(token.key);
// Only aggregations returning a numeric value are supported.
if(propertyValue instanceof Number) {
value = ((Number) propertyValue).doubleValue();
} else {
throw new AggregationExecutionException("Invalid order path ["+this+
"]. Only numeric result are supported.");
}
} else {
throw new AggregationExecutionException("Invalid aggregation type for order path ["+this+
"]. Only numeric & scripted metric aggregation are supported.");
}
}

return value;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
import java.lang.reflect.Array;
import java.util.Collection;
import java.util.Iterator;
import java.util.Map;

/**
* {@link SortingNumericDoubleValues} implementation which is based on a script
Expand All @@ -35,6 +36,8 @@ public class ScriptDoubleValues extends SortingNumericDoubleValues implements Sc

final LeafSearchScript script;

Map extendedScriptResult;

public ScriptDoubleValues(LeafSearchScript script) {
super();
this.script = script;
Expand Down Expand Up @@ -68,6 +71,26 @@ else if (value instanceof Collection) {
values[i] = ((Number) it.next()).doubleValue();
}
assert i == count();
} else if(value instanceof Map) {
// Map containing one or several values + some optional extended properties.
// Aggregators can use these extended properties to implement specific behaviors (eg: AvgAggregator using
// the property weight(s) to compute a weighted average instead of standard average).
extendedScriptResult = (Map)value;
Object scriptValue = extendedScriptResult.remove("value");
Object scriptValues = extendedScriptResult.remove("values");
if(scriptValues != null && scriptValues instanceof Collection) {
resize(((Collection)scriptValues).size());
int i = 0;
for (Iterator<?> it = ((Collection<?>) scriptValues).iterator(); it.hasNext(); ++i) {
values[i] = ((Number) it.next()).doubleValue();
}
assert i == count();
} else if(scriptValue != null && scriptValue instanceof Number) {
resize(1);
values[0] = ((Number)scriptValue).doubleValue();
} else {
throw new AggregationExecutionException("Unsupported script value [" + value + "]");
}
}

else {
Expand All @@ -77,6 +100,14 @@ else if (value instanceof Collection) {
sort();
}

/**
 * Gives access to the map stored in {@code extendedScriptResult}.
 *
 * @return the stored map, or {@code null} when none has been stored
 */
public Map getExtendedScriptResult() {
    return this.extendedScriptResult;
}

/**
 * Tells whether an extended script result is available.
 *
 * @return {@code true} only when {@code extendedScriptResult} is non-null and holds at least
 *         one entry
 */
public final boolean hasExtendedScriptResult() {
    if (extendedScriptResult == null) {
        return false;
    }
    return extendedScriptResult.isEmpty() == false;
}

@Override
public void setScorer(Scorer scorer) {
script.setScorer(scorer);
Expand Down