Skip to content

Commit 66e6266

Browse files
rationullcolings86
authored andcommitted
Pass through script params in scripted metric agg (#29154)
* Pass script level params into scripted metric aggs (#28819) Now params that are passed at the script level and at the aggregation level are merged and can both be used in the aggregation scripts. If there are any conflicts, aggregation level params will win. This may be followed by another change detecting that case and throwing an exception to disallow such conflicts. * Disallow duplicate parameter names between scripted agg and script (#28819) If a scripted metric aggregation has aggregation params and script params which have the same name, throw an IllegalArgumentException when merging the parameter lists.
1 parent 77c0fba commit 66e6266

File tree

4 files changed

+146
-26
lines changed

4 files changed

+146
-26
lines changed

server/src/main/java/org/elasticsearch/search/aggregations/metrics/scripted/ScriptedMetricAggregationBuilder.java

+17-2
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@
3737
import org.elasticsearch.search.internal.SearchContext;
3838

3939
import java.io.IOException;
40+
import java.util.Collections;
4041
import java.util.Map;
4142
import java.util.Objects;
4243

@@ -198,20 +199,34 @@ protected ScriptedMetricAggregatorFactory doBuild(SearchContext context, Aggrega
198199
Builder subfactoriesBuilder) throws IOException {
199200

200201
QueryShardContext queryShardContext = context.getQueryShardContext();
202+
203+
// Extract params from scripts and pass them along to ScriptedMetricAggregatorFactory, since it won't have
204+
// access to them for the scripts it's given precompiled.
205+
201206
ExecutableScript.Factory executableInitScript;
207+
Map<String, Object> initScriptParams;
202208
if (initScript != null) {
203209
executableInitScript = queryShardContext.getScriptService().compile(initScript, ExecutableScript.AGGS_CONTEXT);
210+
initScriptParams = initScript.getParams();
204211
} else {
205212
executableInitScript = p -> null;
213+
initScriptParams = Collections.emptyMap();
206214
}
215+
207216
SearchScript.Factory searchMapScript = queryShardContext.getScriptService().compile(mapScript, SearchScript.AGGS_CONTEXT);
217+
Map<String, Object> mapScriptParams = mapScript.getParams();
218+
208219
ExecutableScript.Factory executableCombineScript;
220+
Map<String, Object> combineScriptParams;
209221
if (combineScript != null) {
210-
executableCombineScript =queryShardContext.getScriptService().compile(combineScript, ExecutableScript.AGGS_CONTEXT);
222+
executableCombineScript = queryShardContext.getScriptService().compile(combineScript, ExecutableScript.AGGS_CONTEXT);
223+
combineScriptParams = combineScript.getParams();
211224
} else {
212225
executableCombineScript = p -> null;
226+
combineScriptParams = Collections.emptyMap();
213227
}
214-
return new ScriptedMetricAggregatorFactory(name, searchMapScript, executableInitScript, executableCombineScript, reduceScript,
228+
return new ScriptedMetricAggregatorFactory(name, searchMapScript, mapScriptParams, executableInitScript, initScriptParams,
229+
executableCombineScript, combineScriptParams, reduceScript,
215230
params, queryShardContext.lookup(), context, parent, subfactoriesBuilder, metaData);
216231
}
217232

server/src/main/java/org/elasticsearch/search/aggregations/metrics/scripted/ScriptedMetricAggregatorFactory.java

+35-15
Original file line numberDiff line numberDiff line change
@@ -35,28 +35,35 @@
3535
import java.util.HashMap;
3636
import java.util.List;
3737
import java.util.Map;
38-
import java.util.function.Function;
3938

4039
public class ScriptedMetricAggregatorFactory extends AggregatorFactory<ScriptedMetricAggregatorFactory> {
4140

4241
private final SearchScript.Factory mapScript;
42+
private final Map<String, Object> mapScriptParams;
4343
private final ExecutableScript.Factory combineScript;
44+
private final Map<String, Object> combineScriptParams;
4445
private final Script reduceScript;
45-
private final Map<String, Object> params;
46+
private final Map<String, Object> aggParams;
4647
private final SearchLookup lookup;
4748
private final ExecutableScript.Factory initScript;
49+
private final Map<String, Object> initScriptParams;
4850

49-
public ScriptedMetricAggregatorFactory(String name, SearchScript.Factory mapScript, ExecutableScript.Factory initScript,
50-
ExecutableScript.Factory combineScript, Script reduceScript, Map<String, Object> params,
51+
public ScriptedMetricAggregatorFactory(String name, SearchScript.Factory mapScript, Map<String, Object> mapScriptParams,
52+
ExecutableScript.Factory initScript, Map<String, Object> initScriptParams,
53+
ExecutableScript.Factory combineScript, Map<String, Object> combineScriptParams,
54+
Script reduceScript, Map<String, Object> aggParams,
5155
SearchLookup lookup, SearchContext context, AggregatorFactory<?> parent,
5256
AggregatorFactories.Builder subFactories, Map<String, Object> metaData) throws IOException {
5357
super(name, context, parent, subFactories, metaData);
5458
this.mapScript = mapScript;
59+
this.mapScriptParams = mapScriptParams;
5560
this.initScript = initScript;
61+
this.initScriptParams = initScriptParams;
5662
this.combineScript = combineScript;
63+
this.combineScriptParams = combineScriptParams;
5764
this.reduceScript = reduceScript;
5865
this.lookup = lookup;
59-
this.params = params;
66+
this.aggParams = aggParams;
6067
}
6168

6269
@Override
@@ -65,26 +72,26 @@ public Aggregator createInternal(Aggregator parent, boolean collectsFromSingleBu
6572
if (collectsFromSingleBucket == false) {
6673
return asMultiBucketAggregator(this, context, parent);
6774
}
68-
Map<String, Object> params = this.params;
69-
if (params != null) {
70-
params = deepCopyParams(params, context);
75+
Map<String, Object> aggParams = this.aggParams;
76+
if (aggParams != null) {
77+
aggParams = deepCopyParams(aggParams, context);
7178
} else {
72-
params = new HashMap<>();
79+
aggParams = new HashMap<>();
7380
}
74-
if (params.containsKey("_agg") == false) {
75-
params.put("_agg", new HashMap<String, Object>());
81+
if (aggParams.containsKey("_agg") == false) {
82+
aggParams.put("_agg", new HashMap<String, Object>());
7683
}
7784

78-
final ExecutableScript initScript = this.initScript.newInstance(params);
79-
final SearchScript.LeafFactory mapScript = this.mapScript.newFactory(params, lookup);
80-
final ExecutableScript combineScript = this.combineScript.newInstance(params);
85+
final ExecutableScript initScript = this.initScript.newInstance(mergeParams(aggParams, initScriptParams));
86+
final SearchScript.LeafFactory mapScript = this.mapScript.newFactory(mergeParams(aggParams, mapScriptParams), lookup);
87+
final ExecutableScript combineScript = this.combineScript.newInstance(mergeParams(aggParams, combineScriptParams));
8188

8289
final Script reduceScript = deepCopyScript(this.reduceScript, context);
8390
if (initScript != null) {
8491
initScript.run();
8592
}
8693
return new ScriptedMetricAggregator(name, mapScript,
87-
combineScript, reduceScript, params, context, parent,
94+
combineScript, reduceScript, aggParams, context, parent,
8895
pipelineAggregators, metaData);
8996
}
9097

@@ -128,5 +135,18 @@ private static <T> T deepCopyParams(T original, SearchContext context) {
128135
return clone;
129136
}
130137

138+
private static Map<String, Object> mergeParams(Map<String, Object> agg, Map<String, Object> script) {
139+
// Start with script params
140+
Map<String, Object> combined = new HashMap<>(script);
131141

142+
// Add in agg params, throwing an exception if any conflicts are detected
143+
for (Map.Entry<String, Object> aggEntry : agg.entrySet()) {
144+
if (combined.putIfAbsent(aggEntry.getKey(), aggEntry.getValue()) != null) {
145+
throw new IllegalArgumentException("Parameter name \"" + aggEntry.getKey() +
146+
"\" used in both aggregation and script parameters");
147+
}
148+
}
149+
150+
return combined;
151+
}
132152
}

server/src/test/java/org/elasticsearch/search/aggregations/metrics/ScriptedMetricIT.java

+23-8
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,8 @@
2020
package org.elasticsearch.search.aggregations.metrics;
2121

2222
import org.elasticsearch.action.index.IndexRequestBuilder;
23+
import org.elasticsearch.action.search.SearchPhaseExecutionException;
24+
import org.elasticsearch.action.search.SearchRequestBuilder;
2325
import org.elasticsearch.action.search.SearchResponse;
2426
import org.elasticsearch.common.bytes.BytesArray;
2527
import org.elasticsearch.common.settings.Settings;
@@ -62,6 +64,7 @@
6264
import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertAcked;
6365
import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertSearchResponse;
6466
import static org.hamcrest.Matchers.allOf;
67+
import static org.hamcrest.Matchers.containsString;
6568
import static org.hamcrest.Matchers.equalTo;
6669
import static org.hamcrest.Matchers.greaterThan;
6770
import static org.hamcrest.Matchers.greaterThanOrEqualTo;
@@ -322,11 +325,11 @@ public void testMap() {
322325
assertThat(numShardsRun, greaterThan(0));
323326
}
324327

325-
public void testMapWithParams() {
328+
public void testExplicitAggParam() {
326329
Map<String, Object> params = new HashMap<>();
327330
params.put("_agg", new ArrayList<>());
328331

329-
Script mapScript = new Script(ScriptType.INLINE, CustomScriptPlugin.NAME, "_agg.add(1)", params);
332+
Script mapScript = new Script(ScriptType.INLINE, CustomScriptPlugin.NAME, "_agg.add(1)", Collections.emptyMap());
330333

331334
SearchResponse response = client().prepareSearch("idx")
332335
.setQuery(matchAllQuery())
@@ -361,17 +364,17 @@ public void testMapWithParams() {
361364
}
362365

363366
public void testMapWithParamsAndImplicitAggMap() {
364-
Map<String, Object> params = new HashMap<>();
365-
// don't put any _agg map in params
366-
params.put("param1", "12");
367-
params.put("param2", 1);
367+
// Split the params up between the script and the aggregation.
368+
// Don't put any _agg map in params.
369+
Map<String, Object> scriptParams = Collections.singletonMap("param1", "12");
370+
Map<String, Object> aggregationParams = Collections.singletonMap("param2", 1);
368371

369372
// The _agg hashmap will be available even if not declared in the params map
370-
Script mapScript = new Script(ScriptType.INLINE, CustomScriptPlugin.NAME, "_agg[param1] = param2", params);
373+
Script mapScript = new Script(ScriptType.INLINE, CustomScriptPlugin.NAME, "_agg[param1] = param2", scriptParams);
371374

372375
SearchResponse response = client().prepareSearch("idx")
373376
.setQuery(matchAllQuery())
374-
.addAggregation(scriptedMetric("scripted").params(params).mapScript(mapScript))
377+
.addAggregation(scriptedMetric("scripted").params(aggregationParams).mapScript(mapScript))
375378
.get();
376379
assertSearchResponse(response);
377380
assertThat(response.getHits().getTotalHits(), equalTo(numDocs));
@@ -1001,4 +1004,16 @@ public void testDontCacheScripts() throws Exception {
10011004
assertThat(client().admin().indices().prepareStats("cache_test_idx").setRequestCache(true).get().getTotal().getRequestCache()
10021005
.getMissCount(), equalTo(0L));
10031006
}
1007+
1008+
public void testConflictingAggAndScriptParams() {
1009+
Map<String, Object> params = Collections.singletonMap("param1", "12");
1010+
Script mapScript = new Script(ScriptType.INLINE, CustomScriptPlugin.NAME, "_agg.add(1)", params);
1011+
1012+
SearchRequestBuilder builder = client().prepareSearch("idx")
1013+
.setQuery(matchAllQuery())
1014+
.addAggregation(scriptedMetric("scripted").params(params).mapScript(mapScript));
1015+
1016+
SearchPhaseExecutionException ex = expectThrows(SearchPhaseExecutionException.class, builder::get);
1017+
assertThat(ex.getCause().getMessage(), containsString("Parameter name \"param1\" used in both aggregation and script parameters"));
1018+
}
10041019
}

server/src/test/java/org/elasticsearch/search/aggregations/metrics/scripted/ScriptedMetricAggregatorTests.java

+71-1
Original file line numberDiff line numberDiff line change
@@ -64,8 +64,16 @@ public class ScriptedMetricAggregatorTests extends AggregatorTestCase {
6464
Collections.emptyMap());
6565
private static final Script COMBINE_SCRIPT_SCORE = new Script(ScriptType.INLINE, MockScriptEngine.NAME, "combineScriptScore",
6666
Collections.emptyMap());
67-
private static final Map<String, Function<Map<String, Object>, Object>> SCRIPTS = new HashMap<>();
6867

68+
private static final Script INIT_SCRIPT_PARAMS = new Script(ScriptType.INLINE, MockScriptEngine.NAME, "initScriptParams",
69+
Collections.singletonMap("initialValue", 24));
70+
private static final Script MAP_SCRIPT_PARAMS = new Script(ScriptType.INLINE, MockScriptEngine.NAME, "mapScriptParams",
71+
Collections.singletonMap("itemValue", 12));
72+
private static final Script COMBINE_SCRIPT_PARAMS = new Script(ScriptType.INLINE, MockScriptEngine.NAME, "combineScriptParams",
73+
Collections.singletonMap("divisor", 4));
74+
private static final String CONFLICTING_PARAM_NAME = "initialValue";
75+
76+
private static final Map<String, Function<Map<String, Object>, Object>> SCRIPTS = new HashMap<>();
6977

7078
@BeforeClass
7179
@SuppressWarnings("unchecked")
@@ -99,6 +107,26 @@ public static void initMockScripts() {
99107
Map<String, Object> agg = (Map<String, Object>) params.get("_agg");
100108
return ((List<Double>) agg.get("collector")).stream().mapToDouble(Double::doubleValue).sum();
101109
});
110+
111+
SCRIPTS.put("initScriptParams", params -> {
112+
Map<String, Object> agg = (Map<String, Object>) params.get("_agg");
113+
Integer initialValue = (Integer)params.get("initialValue");
114+
ArrayList<Integer> collector = new ArrayList();
115+
collector.add(initialValue);
116+
agg.put("collector", collector);
117+
return agg;
118+
});
119+
SCRIPTS.put("mapScriptParams", params -> {
120+
Map<String, Object> agg = (Map<String, Object>) params.get("_agg");
121+
Integer itemValue = (Integer) params.get("itemValue");
122+
((List<Integer>) agg.get("collector")).add(itemValue);
123+
return agg;
124+
});
125+
SCRIPTS.put("combineScriptParams", params -> {
126+
Map<String, Object> agg = (Map<String, Object>) params.get("_agg");
127+
int divisor = ((Integer) params.get("divisor"));
128+
return ((List<Integer>) agg.get("collector")).stream().mapToInt(Integer::intValue).map(i -> i / divisor).sum();
129+
});
102130
}
103131

104132
@SuppressWarnings("unchecked")
@@ -187,6 +215,48 @@ public void testScriptedMetricWithCombineAccessesScores() throws IOException {
187215
}
188216
}
189217

218+
public void testScriptParamsPassedThrough() throws IOException {
219+
try (Directory directory = newDirectory()) {
220+
try (RandomIndexWriter indexWriter = new RandomIndexWriter(random(), directory)) {
221+
for (int i = 0; i < 100; i++) {
222+
indexWriter.addDocument(singleton(new SortedNumericDocValuesField("number", i)));
223+
}
224+
}
225+
226+
try (IndexReader indexReader = DirectoryReader.open(directory)) {
227+
ScriptedMetricAggregationBuilder aggregationBuilder = new ScriptedMetricAggregationBuilder(AGG_NAME);
228+
aggregationBuilder.initScript(INIT_SCRIPT_PARAMS).mapScript(MAP_SCRIPT_PARAMS).combineScript(COMBINE_SCRIPT_PARAMS);
229+
ScriptedMetric scriptedMetric = search(newSearcher(indexReader, true, true), new MatchAllDocsQuery(), aggregationBuilder);
230+
231+
// The result value depends on the script params.
232+
assertEquals(306, scriptedMetric.aggregation());
233+
}
234+
}
235+
}
236+
237+
public void testConflictingAggAndScriptParams() throws IOException {
238+
try (Directory directory = newDirectory()) {
239+
try (RandomIndexWriter indexWriter = new RandomIndexWriter(random(), directory)) {
240+
for (int i = 0; i < 100; i++) {
241+
indexWriter.addDocument(singleton(new SortedNumericDocValuesField("number", i)));
242+
}
243+
}
244+
245+
try (IndexReader indexReader = DirectoryReader.open(directory)) {
246+
ScriptedMetricAggregationBuilder aggregationBuilder = new ScriptedMetricAggregationBuilder(AGG_NAME);
247+
Map<String, Object> aggParams = Collections.singletonMap(CONFLICTING_PARAM_NAME, "blah");
248+
aggregationBuilder.params(aggParams).initScript(INIT_SCRIPT_PARAMS).mapScript(MAP_SCRIPT_PARAMS).
249+
combineScript(COMBINE_SCRIPT_PARAMS);
250+
251+
IllegalArgumentException ex = expectThrows(IllegalArgumentException.class, () ->
252+
search(newSearcher(indexReader, true, true), new MatchAllDocsQuery(), aggregationBuilder)
253+
);
254+
assertEquals("Parameter name \"" + CONFLICTING_PARAM_NAME + "\" used in both aggregation and script parameters",
255+
ex.getMessage());
256+
}
257+
}
258+
}
259+
190260
/**
191261
* We cannot use Mockito for mocking QueryShardContext in this case because
192262
* script-related methods (e.g. QueryShardContext#getLazyExecutableScript)

0 commit comments

Comments
 (0)