Skip to content

Commit 85c26d6

Browse files
rationulls1monw
authored andcommitted
Call ensureNoSelfReferences() on _agg state variable after scripted metric agg script executions (#31044)
Previously this was called for the combine script only. This change checks for self references for init, map, and reduce scripts as well, and adds unit test coverage for the init, map, and combine cases.
1 parent bd5c1a4 commit 85c26d6

File tree

12 files changed

+118
-18
lines changed

12 files changed

+118
-18
lines changed

modules/lang-mustache/src/main/java/org/elasticsearch/script/mustache/CustomReflectionObjectHandler.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -157,7 +157,7 @@ public Iterator<Object> iterator() {
157157

158158
@Override
159159
public String stringify(Object object) {
160-
CollectionUtils.ensureNoSelfReferences(object);
160+
CollectionUtils.ensureNoSelfReferences(object, "CustomReflectionObjectHandler stringify");
161161
return super.stringify(object);
162162
}
163163
}

server/src/main/java/org/elasticsearch/common/util/CollectionUtils.java

Lines changed: 16 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
import java.util.IdentityHashMap;
3030
import java.util.LinkedList;
3131
import java.util.List;
32+
import java.util.Locale;
3233
import java.util.Map;
3334
import java.util.Objects;
3435
import java.util.RandomAccess;
@@ -40,6 +41,7 @@
4041
import org.apache.lucene.util.BytesRefBuilder;
4142
import org.apache.lucene.util.InPlaceMergeSorter;
4243
import org.apache.lucene.util.IntroSorter;
44+
import org.elasticsearch.common.Strings;
4345

4446
/** Collections-related utility methods. */
4547
public class CollectionUtils {
@@ -225,10 +227,17 @@ public static int[] toArray(Collection<Integer> ints) {
225227
return ints.stream().mapToInt(s -> s).toArray();
226228
}
227229

228-
public static void ensureNoSelfReferences(Object value) {
230+
/**
231+
* Deeply inspects a Map, Iterable, or Object array looking for references back to itself.
232+
* @throws IllegalArgumentException if a self-reference is found
233+
* @param value The object to evaluate looking for self references
234+
* @param messageHint A string to be included in the exception message if the call fails, to provide
235+
* more context to the handler of the exception
236+
*/
237+
public static void ensureNoSelfReferences(Object value, String messageHint) {
229238
Iterable<?> it = convert(value);
230239
if (it != null) {
231-
ensureNoSelfReferences(it, value, Collections.newSetFromMap(new IdentityHashMap<>()));
240+
ensureNoSelfReferences(it, value, Collections.newSetFromMap(new IdentityHashMap<>()), messageHint);
232241
}
233242
}
234243

@@ -247,13 +256,15 @@ private static Iterable<?> convert(Object value) {
247256
}
248257
}
249258

250-
private static void ensureNoSelfReferences(final Iterable<?> value, Object originalReference, final Set<Object> ancestors) {
259+
private static void ensureNoSelfReferences(final Iterable<?> value, Object originalReference, final Set<Object> ancestors,
260+
String messageHint) {
251261
if (value != null) {
252262
if (ancestors.add(originalReference) == false) {
253-
throw new IllegalArgumentException("Iterable object is self-referencing itself");
263+
String suffix = Strings.isNullOrEmpty(messageHint) ? "" : String.format(Locale.ROOT, " (%s)", messageHint);
264+
throw new IllegalArgumentException("Iterable object is self-referencing itself" + suffix);
254265
}
255266
for (Object o : value) {
256-
ensureNoSelfReferences(convert(o), o, ancestors);
267+
ensureNoSelfReferences(convert(o), o, ancestors, messageHint);
257268
}
258269
ancestors.remove(originalReference);
259270
}

server/src/main/java/org/elasticsearch/search/aggregations/metrics/scripted/InternalScriptedMetric.java

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121

2222
import org.elasticsearch.common.io.stream.StreamInput;
2323
import org.elasticsearch.common.io.stream.StreamOutput;
24+
import org.elasticsearch.common.util.CollectionUtils;
2425
import org.elasticsearch.common.xcontent.XContentBuilder;
2526
import org.elasticsearch.script.ExecutableScript;
2627
import org.elasticsearch.script.Script;
@@ -97,7 +98,11 @@ public InternalAggregation doReduce(List<InternalAggregation> aggregations, Redu
9798
ExecutableScript.Factory factory = reduceContext.scriptService().compile(
9899
firstAggregation.reduceScript, ExecutableScript.AGGS_CONTEXT);
99100
ExecutableScript script = factory.newInstance(vars);
100-
aggregation = Collections.singletonList(script.run());
101+
102+
Object scriptResult = script.run();
103+
CollectionUtils.ensureNoSelfReferences(scriptResult, "reduce script");
104+
105+
aggregation = Collections.singletonList(scriptResult);
101106
} else if (reduceContext.isFinalReduce()) {
102107
aggregation = Collections.singletonList(aggregationObjects);
103108
} else {

server/src/main/java/org/elasticsearch/search/aggregations/metrics/scripted/ScriptedMetricAggregator.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,7 @@ public void collect(int doc, long bucket) throws IOException {
6969
assert bucket == 0 : bucket;
7070
leafMapScript.setDocument(doc);
7171
leafMapScript.run();
72+
CollectionUtils.ensureNoSelfReferences(params, "Scripted metric aggs map script");
7273
}
7374
};
7475
}
@@ -78,7 +79,7 @@ public InternalAggregation buildAggregation(long owningBucketOrdinal) {
7879
Object aggregation;
7980
if (combineScript != null) {
8081
aggregation = combineScript.run();
81-
CollectionUtils.ensureNoSelfReferences(aggregation);
82+
CollectionUtils.ensureNoSelfReferences(aggregation, "Scripted metric aggs combine script");
8283
} else {
8384
aggregation = params.get("_agg");
8485
}

server/src/main/java/org/elasticsearch/search/aggregations/metrics/scripted/ScriptedMetricAggregatorFactory.java

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919

2020
package org.elasticsearch.search.aggregations.metrics.scripted;
2121

22+
import org.elasticsearch.common.util.CollectionUtils;
2223
import org.elasticsearch.script.ExecutableScript;
2324
import org.elasticsearch.script.Script;
2425
import org.elasticsearch.script.SearchScript;
@@ -89,6 +90,7 @@ public Aggregator createInternal(Aggregator parent, boolean collectsFromSingleBu
8990
final Script reduceScript = deepCopyScript(this.reduceScript, context);
9091
if (initScript != null) {
9192
initScript.run();
93+
CollectionUtils.ensureNoSelfReferences(aggParams.get("_agg"), "Scripted metric aggs init script");
9294
}
9395
return new ScriptedMetricAggregator(name, mapScript,
9496
combineScript, reduceScript, aggParams, context, parent,

server/src/main/java/org/elasticsearch/search/aggregations/support/ValuesSource.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -458,7 +458,7 @@ public boolean advanceExact(int doc) throws IOException {
458458
final BytesRef value = bytesValues.nextValue();
459459
script.setNextAggregationValue(value.utf8ToString());
460460
Object run = script.run();
461-
CollectionUtils.ensureNoSelfReferences(run);
461+
CollectionUtils.ensureNoSelfReferences(run, "ValuesSource.BytesValues script");
462462
values[i].copyChars(run.toString());
463463
}
464464
sort();

server/src/main/java/org/elasticsearch/search/aggregations/support/values/ScriptBytesValues.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ private void set(int i, Object o) {
4545
if (o == null) {
4646
values[i].clear();
4747
} else {
48-
CollectionUtils.ensureNoSelfReferences(o);
48+
CollectionUtils.ensureNoSelfReferences(o, "ScriptBytesValues value");
4949
values[i].copyChars(o.toString());
5050
}
5151
}

server/src/main/java/org/elasticsearch/search/fetch/subphase/ScriptFieldsFetchSubPhase.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,7 @@ public void hitsExecute(SearchContext context, SearchHit[] hits) throws IOExcept
6565
final Object value;
6666
try {
6767
value = leafScripts[i].run();
68-
CollectionUtils.ensureNoSelfReferences(value);
68+
CollectionUtils.ensureNoSelfReferences(value, "ScriptFieldsFetchSubPhase leaf script " + i);
6969
} catch (RuntimeException e) {
7070
if (scriptFields.get(i).ignoreException()) {
7171
continue;

server/src/main/java/org/elasticsearch/search/sort/ScriptSortBuilder.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -343,7 +343,7 @@ public boolean advanceExact(int doc) throws IOException {
343343
@Override
344344
public BytesRef binaryValue() {
345345
final Object run = leafScript.run();
346-
CollectionUtils.ensureNoSelfReferences(run);
346+
CollectionUtils.ensureNoSelfReferences(run, "ScriptSortBuilder leaf script");
347347
spare.copyChars(run.toString());
348348
return spare.get();
349349
}

server/src/test/java/org/elasticsearch/common/util/CollectionUtilsTests.java

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -183,13 +183,14 @@ public void testPerfectPartition() {
183183
}
184184

185185
public void testEnsureNoSelfReferences() {
186-
CollectionUtils.ensureNoSelfReferences(emptyMap());
187-
CollectionUtils.ensureNoSelfReferences(null);
186+
CollectionUtils.ensureNoSelfReferences(emptyMap(), "test with empty map");
187+
CollectionUtils.ensureNoSelfReferences(null, "test with null");
188188

189189
Map<String, Object> map = new HashMap<>();
190190
map.put("field", map);
191191

192-
IllegalArgumentException e = expectThrows(IllegalArgumentException.class, () -> CollectionUtils.ensureNoSelfReferences(map));
193-
assertThat(e.getMessage(), containsString("Iterable object is self-referencing itself"));
192+
IllegalArgumentException e = expectThrows(IllegalArgumentException.class,
193+
() -> CollectionUtils.ensureNoSelfReferences(map, "test with self ref"));
194+
assertThat(e.getMessage(), containsString("Iterable object is self-referencing itself (test with self ref)"));
194195
}
195196
}

server/src/test/java/org/elasticsearch/common/xcontent/BaseXContentTestCase.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -843,8 +843,8 @@ public void testEnsureNotNull() {
843843
}
844844

845845
public void testEnsureNoSelfReferences() throws IOException {
846-
CollectionUtils.ensureNoSelfReferences(emptyMap());
847-
CollectionUtils.ensureNoSelfReferences(null);
846+
builder().map(emptyMap());
847+
builder().map(null);
848848

849849
Map<String, Object> map = new HashMap<>();
850850
map.put("field", map);

server/src/test/java/org/elasticsearch/search/aggregations/metrics/scripted/ScriptedMetricAggregatorTests.java

Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,13 @@ public class ScriptedMetricAggregatorTests extends AggregatorTestCase {
7373
Collections.singletonMap("divisor", 4));
7474
private static final String CONFLICTING_PARAM_NAME = "initialValue";
7575

76+
private static final Script INIT_SCRIPT_SELF_REF = new Script(ScriptType.INLINE, MockScriptEngine.NAME, "initScriptSelfRef",
77+
Collections.emptyMap());
78+
private static final Script MAP_SCRIPT_SELF_REF = new Script(ScriptType.INLINE, MockScriptEngine.NAME, "mapScriptSelfRef",
79+
Collections.emptyMap());
80+
private static final Script COMBINE_SCRIPT_SELF_REF = new Script(ScriptType.INLINE, MockScriptEngine.NAME, "combineScriptSelfRef",
81+
Collections.emptyMap());
82+
7683
private static final Map<String, Function<Map<String, Object>, Object>> SCRIPTS = new HashMap<>();
7784

7885
@BeforeClass
@@ -127,6 +134,25 @@ public static void initMockScripts() {
127134
int divisor = ((Integer) params.get("divisor"));
128135
return ((List<Integer>) agg.get("collector")).stream().mapToInt(Integer::intValue).map(i -> i / divisor).sum();
129136
});
137+
138+
SCRIPTS.put("initScriptSelfRef", params -> {
139+
Map<String, Object> agg = (Map<String, Object>) params.get("_agg");
140+
agg.put("collector", new ArrayList<Integer>());
141+
agg.put("selfRef", agg);
142+
return agg;
143+
});
144+
145+
SCRIPTS.put("mapScriptSelfRef", params -> {
146+
Map<String, Object> agg = (Map<String, Object>) params.get("_agg");
147+
agg.put("selfRef", agg);
148+
return agg;
149+
});
150+
151+
SCRIPTS.put("combineScriptSelfRef", params -> {
152+
Map<String, Object> agg = (Map<String, Object>) params.get("_agg");
153+
agg.put("selfRef", agg);
154+
return agg;
155+
});
130156
}
131157

132158
@SuppressWarnings("unchecked")
@@ -257,6 +283,60 @@ public void testConflictingAggAndScriptParams() throws IOException {
257283
}
258284
}
259285

286+
public void testSelfReferencingAggStateAfterInit() throws IOException {
287+
try (Directory directory = newDirectory()) {
288+
try (RandomIndexWriter indexWriter = new RandomIndexWriter(random(), directory)) {
289+
// No need to add docs for this test
290+
}
291+
try (IndexReader indexReader = DirectoryReader.open(directory)) {
292+
ScriptedMetricAggregationBuilder aggregationBuilder = new ScriptedMetricAggregationBuilder(AGG_NAME);
293+
aggregationBuilder.initScript(INIT_SCRIPT_SELF_REF).mapScript(MAP_SCRIPT);
294+
295+
IllegalArgumentException ex = expectThrows(IllegalArgumentException.class, () ->
296+
search(newSearcher(indexReader, true, true), new MatchAllDocsQuery(), aggregationBuilder)
297+
);
298+
assertEquals("Iterable object is self-referencing itself (Scripted metric aggs init script)", ex.getMessage());
299+
}
300+
}
301+
}
302+
303+
public void testSelfReferencingAggStateAfterMap() throws IOException {
304+
try (Directory directory = newDirectory()) {
305+
Integer numDocs = randomInt(100);
306+
try (RandomIndexWriter indexWriter = new RandomIndexWriter(random(), directory)) {
307+
for (int i = 0; i < numDocs; i++) {
308+
indexWriter.addDocument(singleton(new SortedNumericDocValuesField("number", i)));
309+
}
310+
}
311+
try (IndexReader indexReader = DirectoryReader.open(directory)) {
312+
ScriptedMetricAggregationBuilder aggregationBuilder = new ScriptedMetricAggregationBuilder(AGG_NAME);
313+
aggregationBuilder.initScript(INIT_SCRIPT).mapScript(MAP_SCRIPT_SELF_REF);
314+
315+
IllegalArgumentException ex = expectThrows(IllegalArgumentException.class, () ->
316+
search(newSearcher(indexReader, true, true), new MatchAllDocsQuery(), aggregationBuilder)
317+
);
318+
assertEquals("Iterable object is self-referencing itself (Scripted metric aggs map script)", ex.getMessage());
319+
}
320+
}
321+
}
322+
323+
public void testSelfReferencingAggStateAfterCombine() throws IOException {
324+
try (Directory directory = newDirectory()) {
325+
try (RandomIndexWriter indexWriter = new RandomIndexWriter(random(), directory)) {
326+
// No need to add docs for this test
327+
}
328+
try (IndexReader indexReader = DirectoryReader.open(directory)) {
329+
ScriptedMetricAggregationBuilder aggregationBuilder = new ScriptedMetricAggregationBuilder(AGG_NAME);
330+
aggregationBuilder.initScript(INIT_SCRIPT).mapScript(MAP_SCRIPT).combineScript(COMBINE_SCRIPT_SELF_REF);
331+
332+
IllegalArgumentException ex = expectThrows(IllegalArgumentException.class, () ->
333+
search(newSearcher(indexReader, true, true), new MatchAllDocsQuery(), aggregationBuilder)
334+
);
335+
assertEquals("Iterable object is self-referencing itself (Scripted metric aggs combine script)", ex.getMessage());
336+
}
337+
}
338+
}
339+
260340
/**
261341
* We cannot use Mockito for mocking QueryShardContext in this case because
262342
* script-related methods (e.g. QueryShardContext#getLazyExecutableScript)

0 commit comments

Comments
 (0)