diff --git a/core/src/main/java/org/elasticsearch/search/aggregations/AggregatorFactory.java b/core/src/main/java/org/elasticsearch/search/aggregations/AggregatorFactory.java index 680e3ef2e9255..3af46c7e5bfc2 100644 --- a/core/src/main/java/org/elasticsearch/search/aggregations/AggregatorFactory.java +++ b/core/src/main/java/org/elasticsearch/search/aggregations/AggregatorFactory.java @@ -23,13 +23,15 @@ import org.elasticsearch.common.lease.Releasables; import org.elasticsearch.common.util.BigArrays; import org.elasticsearch.common.util.ObjectArray; +import org.elasticsearch.script.*; +import org.elasticsearch.search.aggregations.metrics.NumericMetricsAggregator; import org.elasticsearch.search.aggregations.pipeline.PipelineAggregator; import org.elasticsearch.search.aggregations.support.AggregationContext; import org.elasticsearch.search.internal.SearchContext.Lifetime; +import org.elasticsearch.transport.TransportRequest; import java.io.IOException; -import java.util.List; -import java.util.Map; +import java.util.*; /** * A factory that knows how to create an {@link Aggregator} of a specific type. @@ -234,4 +236,143 @@ public void close() { }; } + // Specific implementation used by ScriptedMetricAggregator + // In this context we need to provide the reduce function to implement the metric method inherited + // from the NumericMetricsAggregator class. + protected static Aggregator asMultiBucketAggregator(final AggregatorFactory factory, + final AggregationContext context, + final Aggregator parent, + Script reduceScript) throws IOException { + final Aggregator first = factory.create(context, parent, true); + final BigArrays bigArrays = context.bigArrays(); + + first.preCollection(); + + // Returns a NumericMetricsAggregator instead of a simple Aggregator. So the result of the aggregation + // can be used to order the result. + // It's NumericMetricsAggregator.SingleValue because the getProperty method in InternalScriptedMetric + // only supports single value (path: value). + return new NumericMetricsAggregator.SingleValue("",context, parent, new ArrayList<>(), null) { + @Override + public double metric(final long owningBucketOrd) { + try { + Object aggregationObject = buildAggregation(owningBucketOrd).getProperty("value"); + List aggregationObjects = Arrays.asList(aggregationObject); + Map vars = new HashMap<>(); + vars.put("_aggs", aggregationObjects); + + if (reduceScript.getParams() != null) { + vars.putAll(reduceScript.getParams()); + } + ScriptService scriptService = context().searchContext().scriptService(); + CompiledScript compiledScript = scriptService.compile(reduceScript, + ScriptContext.Standard.AGGS, new InternalAggregation.ReduceContext(bigArrays, scriptService, new TransportRequest.Empty())); + ExecutableScript script = scriptService.executable(compiledScript, vars); + Object value = script.run(); + + if(value instanceof Number) { + return ((Number) value).doubleValue(); + } else { + throw new AggregationExecutionException("Invalid order path ["+this+ + "]. Only numeric result are supported."); + } + } catch (IOException e) { + throw new AggregationExecutionException("Failed to build aggregation [" + name() + "]", e); + } + } + + @Override + protected LeafBucketCollector getLeafCollector(final LeafReaderContext ctx, final LeafBucketCollector sub) throws IOException { + for (long i = 0; i < collectors.size(); ++i) { + collectors.set(i, null); + } + return new LeafBucketCollector() { + Scorer scorer; + + @Override + public void setScorer(Scorer scorer) throws IOException { + this.scorer = scorer; + } + + @Override + public void collect(int doc, long bucket) throws IOException { + aggregators = bigArrays.grow(aggregators, bucket + 1); + collectors = bigArrays.grow(collectors, bucket + 1); + + LeafBucketCollector collector = collectors.get(bucket); + if (collector == null) { + Aggregator aggregator = aggregators.get(bucket); + if (aggregator == null) { + aggregator = factory.create(context, parent, true); + aggregator.preCollection(); + aggregators.set(bucket, aggregator); + } + collector = aggregator.getLeafCollector(ctx); + collector.setScorer(scorer); + collectors.set(bucket, collector); + } + collector.collect(doc, 0); + } + + }; + } + + ObjectArray aggregators; + ObjectArray collectors; + + { + context.searchContext().addReleasable(this, Lifetime.PHASE); + aggregators = bigArrays.newObjectArray(1); + aggregators.set(0, first); + collectors = bigArrays.newObjectArray(1); + } + + @Override + public String name() { + return first.name(); + } + + @Override + public AggregationContext context() { + return first.context(); + } + + @Override + public Aggregator parent() { + return first.parent(); + } + + @Override + public boolean needsScores() { + return first.needsScores(); + } + + @Override + public Aggregator subAggregator(String name) { + throw new UnsupportedOperationException(); + } + + @Override + public InternalAggregation buildAggregation(long bucket) throws IOException { + if (bucket < aggregators.size()) { + Aggregator aggregator = aggregators.get(bucket); + if (aggregator != null) { + return aggregator.buildAggregation(0); + } + } + return buildEmptyAggregation(); + } + + @Override + public InternalAggregation buildEmptyAggregation() { + return first.buildEmptyAggregation(); + } + + @Override + public void close() { + Releasables.close(aggregators, collectors); + } + }; + } + } diff --git a/core/src/main/java/org/elasticsearch/search/aggregations/metrics/avg/AvgAggregator.java b/core/src/main/java/org/elasticsearch/search/aggregations/metrics/avg/AvgAggregator.java index 67a6f19f72cbc..998f28975ee73 100644 --- a/core/src/main/java/org/elasticsearch/search/aggregations/metrics/avg/AvgAggregator.java +++ b/core/src/main/java/org/elasticsearch/search/aggregations/metrics/avg/AvgAggregator.java @@ -24,10 +24,7 @@ import org.elasticsearch.common.util.DoubleArray; import org.elasticsearch.common.util.LongArray; import org.elasticsearch.index.fielddata.SortedNumericDoubleValues; -import org.elasticsearch.search.aggregations.Aggregator; -import org.elasticsearch.search.aggregations.InternalAggregation; -import org.elasticsearch.search.aggregations.LeafBucketCollector; -import org.elasticsearch.search.aggregations.LeafBucketCollectorBase; +import org.elasticsearch.search.aggregations.*; import org.elasticsearch.search.aggregations.metrics.NumericMetricsAggregator; import org.elasticsearch.search.aggregations.pipeline.PipelineAggregator; import org.elasticsearch.search.aggregations.support.AggregationContext; @@ -35,6 +32,7 @@ import org.elasticsearch.search.aggregations.support.ValuesSourceAggregatorFactory; import org.elasticsearch.search.aggregations.support.ValuesSourceConfig; import org.elasticsearch.search.aggregations.support.format.ValueFormatter; +import org.elasticsearch.search.aggregations.support.values.ScriptDoubleValues; import java.io.IOException; import java.util.List; @@ -52,8 +50,8 @@ public class AvgAggregator extends NumericMetricsAggregator.SingleValue { ValueFormatter formatter; public AvgAggregator(String name, ValuesSource.Numeric valuesSource, ValueFormatter formatter, - AggregationContext context, - Aggregator parent, List pipelineAggregators, Map metaData) throws IOException { + AggregationContext context, + Aggregator parent, List pipelineAggregators, Map metaData) throws IOException { super(name, context, parent, pipelineAggregators, metaData); this.valuesSource = valuesSource; this.formatter = formatter; @@ -71,7 +69,7 @@ public boolean needsScores() { @Override public LeafBucketCollector getLeafCollector(LeafReaderContext ctx, - final LeafBucketCollector sub) throws IOException { + final LeafBucketCollector sub) throws IOException { if (valuesSource == null) { return LeafBucketCollector.NO_OP_COLLECTOR; } @@ -85,11 +83,47 @@ public void collect(int doc, long bucket) throws IOException { values.setDocument(doc); final int valueCount = values.count(); - counts.increment(bucket, valueCount); double sum = 0; - for (int i = 0; i < valueCount; i++) { - sum += values.valueAt(i); + long count = 0; + if (values instanceof ScriptDoubleValues && ((ScriptDoubleValues)values).hasExtendedScriptResult()) { + // Avg aggregation with a script returning a weight or a list of weights + ScriptDoubleValues scriptDoubleValues = (ScriptDoubleValues)values; + Object weightObj = scriptDoubleValues.getExtendedScriptResult().get("weight"); + Object weightsObj = scriptDoubleValues.getExtendedScriptResult().get("weights"); + if(weightsObj instanceof List) { + List weights = (List) weightsObj; + for (int i = 0; i < valueCount; i++) { + // By default, a missing weight is considered equals to 1. + long weight = 1; + if(i < weights.size()) { + Object weightObjAtIndexI = weights.get(i); + if(weightObjAtIndexI instanceof Number) { + weight = ((Number)weightObjAtIndexI).longValue(); + } else { + throw new AggregationExecutionException("Unsupported weight value [" + weightObjAtIndexI + "], should be a number"); + } + } + count += weight; + sum += values.valueAt(i) * weight; + } + } else { + // by default weight is equals to 0 + long weight = 1; + if(weightObj instanceof Number) { + weight = ((Number)weightObj).longValue(); + } + for (int i = 0; i < valueCount; i++) { + count += weight; + sum += values.valueAt(i) * weight; + } + } + } else { + count += valueCount; + for (int i = 0; i < valueCount; i++) { + sum += values.valueAt(i); + } } + counts.increment(bucket, count); sums.increment(bucket, sum); } }; @@ -121,14 +155,14 @@ public Factory(String name, String type, ValuesSourceConfig pipelineAggregators, - Map metaData) throws IOException { + List pipelineAggregators, + Map metaData) throws IOException { return new AvgAggregator(name, null, config.formatter(), aggregationContext, parent, pipelineAggregators, metaData); } @Override protected Aggregator doCreateInternal(ValuesSource.Numeric valuesSource, AggregationContext aggregationContext, Aggregator parent, - boolean collectsFromSingleBucket, List pipelineAggregators, Map metaData) + boolean collectsFromSingleBucket, List pipelineAggregators, Map metaData) throws IOException { return new AvgAggregator(name, valuesSource, config.formatter(), aggregationContext, parent, pipelineAggregators, metaData); } diff --git a/core/src/main/java/org/elasticsearch/search/aggregations/metrics/scripted/ScriptedMetricAggregator.java b/core/src/main/java/org/elasticsearch/search/aggregations/metrics/scripted/ScriptedMetricAggregator.java index 2c1caaa5241c5..ef1baf6020fd6 100644 --- a/core/src/main/java/org/elasticsearch/search/aggregations/metrics/scripted/ScriptedMetricAggregator.java +++ b/core/src/main/java/org/elasticsearch/search/aggregations/metrics/scripted/ScriptedMetricAggregator.java @@ -127,7 +127,7 @@ public Factory(String name, Script initScript, Script mapScript, Script combineS public Aggregator createInternal(AggregationContext context, Aggregator parent, boolean collectsFromSingleBucket, List pipelineAggregators, Map metaData) throws IOException { if (collectsFromSingleBucket == false) { - return asMultiBucketAggregator(this, context, parent); + return asMultiBucketAggregator(this, context, parent, reduceScript); } Map params = this.params; if (params != null) { diff --git a/core/src/main/java/org/elasticsearch/search/aggregations/support/AggregationPath.java b/core/src/main/java/org/elasticsearch/search/aggregations/support/AggregationPath.java index 84fd26a74f351..ec052a835ab30 100644 --- a/core/src/main/java/org/elasticsearch/search/aggregations/support/AggregationPath.java +++ b/core/src/main/java/org/elasticsearch/search/aggregations/support/AggregationPath.java @@ -20,10 +20,7 @@ package org.elasticsearch.search.aggregations.support; import org.elasticsearch.common.Strings; -import org.elasticsearch.search.aggregations.Aggregation; -import org.elasticsearch.search.aggregations.AggregationExecutionException; -import org.elasticsearch.search.aggregations.Aggregator; -import org.elasticsearch.search.aggregations.HasAggregations; +import org.elasticsearch.search.aggregations.*; import org.elasticsearch.search.aggregations.bucket.SingleBucketAggregation; import org.elasticsearch.search.aggregations.bucket.SingleBucketAggregator; import org.elasticsearch.search.aggregations.metrics.InternalNumericMetricsAggregation; @@ -240,7 +237,24 @@ public double resolveValue(HasAggregations root) { "]. Missing value key in [" + token + "] which refers to a multi-value metric aggregation"); } parent = null; - value = ((InternalNumericMetricsAggregation.MultiValue) agg).value(token.key); + if(agg instanceof InternalNumericMetricsAggregation.MultiValue) { + // For NumericMetricsAggregation we use the method value returning a native double + // Optimization to avoid object creation and multiple casts. + value = ((InternalNumericMetricsAggregation.MultiValue) agg).value(token.key); + } else if(agg instanceof InternalAggregation) { + // For a general use case, we use the method getProperty returning an Object + Object propertyValue = agg.getProperty(token.key); + // Only aggregation returning a numeric value are supported. + if(propertyValue instanceof Number) { + value = ((Number) propertyValue).doubleValue(); + } else { + throw new AggregationExecutionException("Invalid order path ["+this+ + "]. Only numeric result are supported."); + } + } else { + throw new AggregationExecutionException("Invalid aggregation type for order path ["+this+ + "]. Only numeric & scripted metric aggregation are supported."); + } } return value; diff --git a/core/src/main/java/org/elasticsearch/search/aggregations/support/values/ScriptDoubleValues.java b/core/src/main/java/org/elasticsearch/search/aggregations/support/values/ScriptDoubleValues.java index ee9e1272e8310..0092d31d1ea5b 100644 --- a/core/src/main/java/org/elasticsearch/search/aggregations/support/values/ScriptDoubleValues.java +++ b/core/src/main/java/org/elasticsearch/search/aggregations/support/values/ScriptDoubleValues.java @@ -27,6 +27,7 @@ import java.lang.reflect.Array; import java.util.Collection; import java.util.Iterator; +import java.util.Map; /** * {@link SortingNumericDoubleValues} implementation which is based on a script @@ -35,6 +36,8 @@ public class ScriptDoubleValues extends SortingNumericDoubleValues implements Sc final LeafSearchScript script; + Map extendedScriptResult; + public ScriptDoubleValues(LeafSearchScript script) { super(); this.script = script; @@ -68,6 +71,26 @@ else if (value instanceof Collection) { values[i] = ((Number) it.next()).doubleValue(); } assert i == count(); + } else if(value instanceof Map) { + // Map containing one or several values + some optional extended properties. + // Aggregators can use these extended properties to implement specific behaviors (eg: AvgAggregator using + // the property weight(s) to compute a weighted average instead of standard average). + extendedScriptResult = (Map)value; + Object scriptValue = extendedScriptResult.remove("value"); + Object scriptValues = extendedScriptResult.remove("values"); + if(scriptValues != null && scriptValues instanceof Collection) { + resize(((Collection)scriptValues).size()); + int i = 0; + for (Iterator it = ((Collection) scriptValues).iterator(); it.hasNext(); ++i) { + values[i] = ((Number) it.next()).doubleValue(); + } + assert i == count(); + } else if(scriptValue != null && scriptValue instanceof Number) { + resize(1); + values[0] = ((Number)scriptValue).doubleValue(); + } else { + throw new AggregationExecutionException("Unsupported script value [" + value + "]"); + } } else { @@ -77,6 +100,14 @@ else if (value instanceof Collection) { sort(); } + public Map getExtendedScriptResult() { + return extendedScriptResult; + } + + public final boolean hasExtendedScriptResult() { + return extendedScriptResult != null && !extendedScriptResult.isEmpty(); + } + @Override public void setScorer(Scorer scorer) { script.setScorer(scorer);