Skip to content

Commit ccf2417

Browse files
committed
Pass script level params into scripted metric aggs (elastic#28819)
Now params that are passed at the script level and at the aggregation level are merged and can both be used in the aggregation scripts. If there are any conflicts, aggregation level params will win. This may be followed by another change detecting that case and throwing an exception to disallow such conflicts.
1 parent 0abf51a commit ccf2417

File tree

4 files changed

+103
-24
lines changed

4 files changed

+103
-24
lines changed

server/src/main/java/org/elasticsearch/search/aggregations/metrics/scripted/ScriptedMetricAggregationBuilder.java

+17-2
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@
3737
import org.elasticsearch.search.internal.SearchContext;
3838

3939
import java.io.IOException;
40+
import java.util.Collections;
4041
import java.util.Map;
4142
import java.util.Objects;
4243

@@ -198,20 +199,34 @@ protected ScriptedMetricAggregatorFactory doBuild(SearchContext context, Aggrega
198199
Builder subfactoriesBuilder) throws IOException {
199200

200201
QueryShardContext queryShardContext = context.getQueryShardContext();
202+
203+
// Extract params from scripts and pass them along to ScriptedMetricAggregatorFactory, since it won't have
204+
// access to them for the scripts it's given precompiled.
205+
201206
ExecutableScript.Factory executableInitScript;
207+
Map<String, Object> initScriptParams;
202208
if (initScript != null) {
203209
executableInitScript = queryShardContext.getScriptService().compile(initScript, ExecutableScript.AGGS_CONTEXT);
210+
initScriptParams = initScript.getParams();
204211
} else {
205212
executableInitScript = p -> null;
213+
initScriptParams = Collections.emptyMap();
206214
}
215+
207216
SearchScript.Factory searchMapScript = queryShardContext.getScriptService().compile(mapScript, SearchScript.AGGS_CONTEXT);
217+
Map<String, Object> mapScriptParams = mapScript.getParams();
218+
208219
ExecutableScript.Factory executableCombineScript;
220+
Map<String, Object> combineScriptParams;
209221
if (combineScript != null) {
210-
executableCombineScript =queryShardContext.getScriptService().compile(combineScript, ExecutableScript.AGGS_CONTEXT);
222+
executableCombineScript = queryShardContext.getScriptService().compile(combineScript, ExecutableScript.AGGS_CONTEXT);
223+
combineScriptParams = combineScript.getParams();
211224
} else {
212225
executableCombineScript = p -> null;
226+
combineScriptParams = Collections.emptyMap();
213227
}
214-
return new ScriptedMetricAggregatorFactory(name, searchMapScript, executableInitScript, executableCombineScript, reduceScript,
228+
return new ScriptedMetricAggregatorFactory(name, searchMapScript, mapScriptParams, executableInitScript, initScriptParams,
229+
executableCombineScript, combineScriptParams, reduceScript,
215230
params, queryShardContext.lookup(), context, parent, subfactoriesBuilder, metaData);
216231
}
217232

server/src/main/java/org/elasticsearch/search/aggregations/metrics/scripted/ScriptedMetricAggregatorFactory.java

+33-15
Original file line numberDiff line numberDiff line change
@@ -35,28 +35,35 @@
3535
import java.util.HashMap;
3636
import java.util.List;
3737
import java.util.Map;
38-
import java.util.function.Function;
3938

4039
public class ScriptedMetricAggregatorFactory extends AggregatorFactory<ScriptedMetricAggregatorFactory> {
4140

4241
private final SearchScript.Factory mapScript;
42+
private final Map<String, Object> mapScriptParams;
4343
private final ExecutableScript.Factory combineScript;
44+
private final Map<String, Object> combineScriptParams;
4445
private final Script reduceScript;
45-
private final Map<String, Object> params;
46+
private final Map<String, Object> aggParams;
4647
private final SearchLookup lookup;
4748
private final ExecutableScript.Factory initScript;
49+
private final Map<String, Object> initScriptParams;
4850

49-
public ScriptedMetricAggregatorFactory(String name, SearchScript.Factory mapScript, ExecutableScript.Factory initScript,
50-
ExecutableScript.Factory combineScript, Script reduceScript, Map<String, Object> params,
51+
public ScriptedMetricAggregatorFactory(String name, SearchScript.Factory mapScript, Map<String, Object> mapScriptParams,
52+
ExecutableScript.Factory initScript, Map<String, Object> initScriptParams,
53+
ExecutableScript.Factory combineScript, Map<String, Object> combineScriptParams,
54+
Script reduceScript, Map<String, Object> aggParams,
5155
SearchLookup lookup, SearchContext context, AggregatorFactory<?> parent,
5256
AggregatorFactories.Builder subFactories, Map<String, Object> metaData) throws IOException {
5357
super(name, context, parent, subFactories, metaData);
5458
this.mapScript = mapScript;
59+
this.mapScriptParams = mapScriptParams;
5560
this.initScript = initScript;
61+
this.initScriptParams = initScriptParams;
5662
this.combineScript = combineScript;
63+
this.combineScriptParams = combineScriptParams;
5764
this.reduceScript = reduceScript;
5865
this.lookup = lookup;
59-
this.params = params;
66+
this.aggParams = aggParams;
6067
}
6168

6269
@Override
@@ -65,26 +72,26 @@ public Aggregator createInternal(Aggregator parent, boolean collectsFromSingleBu
6572
if (collectsFromSingleBucket == false) {
6673
return asMultiBucketAggregator(this, context, parent);
6774
}
68-
Map<String, Object> params = this.params;
69-
if (params != null) {
70-
params = deepCopyParams(params, context);
75+
Map<String, Object> aggParams = this.aggParams;
76+
if (aggParams != null) {
77+
aggParams = deepCopyParams(aggParams, context);
7178
} else {
72-
params = new HashMap<>();
79+
aggParams = new HashMap<>();
7380
}
74-
if (params.containsKey("_agg") == false) {
75-
params.put("_agg", new HashMap<String, Object>());
81+
if (aggParams.containsKey("_agg") == false) {
82+
aggParams.put("_agg", new HashMap<String, Object>());
7683
}
7784

78-
final ExecutableScript initScript = this.initScript.newInstance(params);
79-
final SearchScript.LeafFactory mapScript = this.mapScript.newFactory(params, lookup);
80-
final ExecutableScript combineScript = this.combineScript.newInstance(params);
85+
final ExecutableScript initScript = this.initScript.newInstance(mergeParams(aggParams, initScriptParams));
86+
final SearchScript.LeafFactory mapScript = this.mapScript.newFactory(mergeParams(aggParams, mapScriptParams), lookup);
87+
final ExecutableScript combineScript = this.combineScript.newInstance(mergeParams(aggParams, combineScriptParams));
8188

8289
final Script reduceScript = deepCopyScript(this.reduceScript, context);
8390
if (initScript != null) {
8491
initScript.run();
8592
}
8693
return new ScriptedMetricAggregator(name, mapScript,
87-
combineScript, reduceScript, params, context, parent,
94+
combineScript, reduceScript, aggParams, context, parent,
8895
pipelineAggregators, metaData);
8996
}
9097

@@ -128,5 +135,16 @@ private static <T> T deepCopyParams(T original, SearchContext context) {
128135
return clone;
129136
}
130137

138+
private static Map<String, Object> mergeParams(Map<String, Object> agg, Map<String, Object> script) {
139+
// TODO Should we throw an exception when param names conflict between aggregation and script? Need to add test coverage
140+
// for error or override behavior depending on the decision. Should this check be added at call time or at
141+
// construction?
131142

143+
// Aggregation level commands need to win in case of conflict so that params can keep the same identity and
144+
// content across all the scripts that are run in the aggregation.
145+
Map<String, Object> combined = new HashMap<>();
146+
combined.putAll(script);
147+
combined.putAll(agg);
148+
return combined;
149+
}
132150
}

server/src/test/java/org/elasticsearch/search/aggregations/metrics/ScriptedMetricIT.java

+6-6
Original file line numberDiff line numberDiff line change
@@ -361,17 +361,17 @@ public void testMapWithParams() {
361361
}
362362

363363
public void testMapWithParamsAndImplicitAggMap() {
364-
Map<String, Object> params = new HashMap<>();
365-
// don't put any _agg map in params
366-
params.put("param1", "12");
367-
params.put("param2", 1);
364+
// Split the params up between the script and the aggregation.
365+
// Don't put any _agg map in params.
366+
Map<String, Object> scriptParams = Collections.singletonMap("param1", "12");
367+
Map<String, Object> aggregationParams = Collections.singletonMap("param2", 1);
368368

369369
// The _agg hashmap will be available even if not declared in the params map
370-
Script mapScript = new Script(ScriptType.INLINE, CustomScriptPlugin.NAME, "_agg[param1] = param2", params);
370+
Script mapScript = new Script(ScriptType.INLINE, CustomScriptPlugin.NAME, "_agg[param1] = param2", scriptParams);
371371

372372
SearchResponse response = client().prepareSearch("idx")
373373
.setQuery(matchAllQuery())
374-
.addAggregation(scriptedMetric("scripted").params(params).mapScript(mapScript))
374+
.addAggregation(scriptedMetric("scripted").params(aggregationParams).mapScript(mapScript))
375375
.get();
376376
assertSearchResponse(response);
377377
assertThat(response.getHits().getTotalHits(), equalTo(numDocs));

server/src/test/java/org/elasticsearch/search/aggregations/metrics/scripted/ScriptedMetricAggregatorTests.java

+47-1
Original file line numberDiff line numberDiff line change
@@ -64,8 +64,15 @@ public class ScriptedMetricAggregatorTests extends AggregatorTestCase {
6464
Collections.emptyMap());
6565
private static final Script COMBINE_SCRIPT_SCORE = new Script(ScriptType.INLINE, MockScriptEngine.NAME, "combineScriptScore",
6666
Collections.emptyMap());
67-
private static final Map<String, Function<Map<String, Object>, Object>> SCRIPTS = new HashMap<>();
6867

68+
private static final Script INIT_SCRIPT_PARAMS = new Script(ScriptType.INLINE, MockScriptEngine.NAME, "initScriptParams",
69+
Collections.singletonMap("initialValue", 24));
70+
private static final Script MAP_SCRIPT_PARAMS = new Script(ScriptType.INLINE, MockScriptEngine.NAME, "mapScriptParams",
71+
Collections.singletonMap("itemValue", 12));
72+
private static final Script COMBINE_SCRIPT_PARAMS = new Script(ScriptType.INLINE, MockScriptEngine.NAME, "combineScriptParams",
73+
Collections.singletonMap("divisor", 4));
74+
75+
private static final Map<String, Function<Map<String, Object>, Object>> SCRIPTS = new HashMap<>();
6976

7077
@BeforeClass
7178
@SuppressWarnings("unchecked")
@@ -99,6 +106,26 @@ public static void initMockScripts() {
99106
Map<String, Object> agg = (Map<String, Object>) params.get("_agg");
100107
return ((List<Double>) agg.get("collector")).stream().mapToDouble(Double::doubleValue).sum();
101108
});
109+
110+
SCRIPTS.put("initScriptParams", params -> {
111+
Map<String, Object> agg = (Map<String, Object>) params.get("_agg");
112+
Integer initialValue = (Integer)params.get("initialValue");
113+
ArrayList<Integer> collector = new ArrayList();
114+
collector.add(initialValue);
115+
agg.put("collector", collector);
116+
return agg;
117+
});
118+
SCRIPTS.put("mapScriptParams", params -> {
119+
Map<String, Object> agg = (Map<String, Object>) params.get("_agg");
120+
Integer itemValue = (Integer) params.get("itemValue");
121+
((List<Integer>) agg.get("collector")).add(itemValue);
122+
return agg;
123+
});
124+
SCRIPTS.put("combineScriptParams", params -> {
125+
Map<String, Object> agg = (Map<String, Object>) params.get("_agg");
126+
int divisor = ((Integer) params.get("divisor"));
127+
return ((List<Integer>) agg.get("collector")).stream().mapToInt(Integer::intValue).map(i -> i / divisor).sum();
128+
});
102129
}
103130

104131
@SuppressWarnings("unchecked")
@@ -187,6 +214,25 @@ public void testScriptedMetricWithCombineAccessesScores() throws IOException {
187214
}
188215
}
189216

217+
public void testScriptParamsPassedThrough() throws IOException {
218+
try (Directory directory = newDirectory()) {
219+
try (RandomIndexWriter indexWriter = new RandomIndexWriter(random(), directory)) {
220+
for (int i = 0; i < 100; i++) {
221+
indexWriter.addDocument(singleton(new SortedNumericDocValuesField("number", i)));
222+
}
223+
}
224+
225+
try (IndexReader indexReader = DirectoryReader.open(directory)) {
226+
ScriptedMetricAggregationBuilder aggregationBuilder = new ScriptedMetricAggregationBuilder(AGG_NAME);
227+
aggregationBuilder.initScript(INIT_SCRIPT_PARAMS).mapScript(MAP_SCRIPT_PARAMS).combineScript(COMBINE_SCRIPT_PARAMS);
228+
ScriptedMetric scriptedMetric = search(newSearcher(indexReader, true, true), new MatchAllDocsQuery(), aggregationBuilder);
229+
230+
// The result value depends on the script params.
231+
assertEquals(306, scriptedMetric.aggregation());
232+
}
233+
}
234+
}
235+
190236
/**
191237
* We cannot use Mockito for mocking QueryShardContext in this case because
192238
* script-related methods (e.g. QueryShardContext#getLazyExecutableScript)

0 commit comments

Comments
 (0)