Skip to content

Commit 02ff060

Browse files
authored
Support BucketScript paths of type string and array. (#44694)
1 parent 69ada4d commit 02ff060

File tree

3 files changed

+233
-1
lines changed

3 files changed

+233
-1
lines changed

server/src/main/java/org/elasticsearch/search/aggregations/pipeline/BucketScriptPipelineAggregationBuilder.java

Lines changed: 26 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
import org.elasticsearch.search.aggregations.pipeline.BucketHelpers.GapPolicy;
3131

3232
import java.io.IOException;
33+
import java.util.Collections;
3334
import java.util.HashMap;
3435
import java.util.Locale;
3536
import java.util.Map;
@@ -59,7 +60,10 @@ public class BucketScriptPipelineAggregationBuilder extends AbstractPipelineAggr
5960
false,
6061
o -> new BucketScriptPipelineAggregationBuilder(name, (Map<String, String>) o[0], (Script) o[1]));
6162

62-
parser.declareObject(ConstructingObjectParser.constructorArg(), (p, c) -> p.mapStrings(), BUCKETS_PATH_FIELD);
63+
parser.declareField(ConstructingObjectParser.constructorArg()
64+
, BucketScriptPipelineAggregationBuilder::extractBucketPath
65+
, BUCKETS_PATH_FIELD
66+
, ObjectParser.ValueType.OBJECT_ARRAY_OR_STRING);
6367
parser.declareField(ConstructingObjectParser.constructorArg(),
6468
(p, c) -> Script.parse(p), Script.SCRIPT_PARSE_FIELD, ObjectParser.ValueType.OBJECT_OR_STRING);
6569

@@ -112,6 +116,27 @@ protected void doWriteTo(StreamOutput out) throws IOException {
112116
gapPolicy.writeTo(out);
113117
}
114118

119+
private static Map<String, String> extractBucketPath(XContentParser parser) throws IOException {
120+
XContentParser.Token token = parser.currentToken();
121+
if (token == XContentParser.Token.VALUE_STRING) {
122+
// input is a string, name of the path set to '_value'.
123+
// This is a bit odd as there is not constructor for it
124+
return Collections.singletonMap("_value", parser.text());
125+
} else if (token == XContentParser.Token.START_ARRAY) {
126+
// input is an array, name of the path set to '_value' + position
127+
Map<String, String> bucketsPathsMap = new HashMap<>();
128+
int i =0;
129+
while ((parser.nextToken()) != XContentParser.Token.END_ARRAY) {
130+
String path = parser.text();
131+
bucketsPathsMap.put("_value" + i++, path);
132+
}
133+
return bucketsPathsMap;
134+
} else {
135+
// input is an object, it should contain name / value pairs
136+
return parser.mapStrings();
137+
}
138+
}
139+
115140
private static Map<String, String> convertToBucketsPathMap(String[] bucketsPaths) {
116141
Map<String, String> bucketsPathsMap = new HashMap<>();
117142
for (int i = 0; i < bucketsPaths.length; i++) {

server/src/test/java/org/elasticsearch/search/aggregations/pipeline/BucketScriptIT.java

Lines changed: 161 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
import org.elasticsearch.action.search.SearchResponse;
2424
import org.elasticsearch.common.bytes.BytesArray;
2525
import org.elasticsearch.common.xcontent.XContentBuilder;
26+
import org.elasticsearch.common.xcontent.XContentFactory;
2627
import org.elasticsearch.common.xcontent.XContentType;
2728
import org.elasticsearch.plugins.Plugin;
2829
import org.elasticsearch.script.MockScriptPlugin;
@@ -117,6 +118,11 @@ protected Map<String, Function<Map<String, Object>, Object>> pluginScripts() {
117118
return value0 + value1 + value2;
118119
});
119120

121+
scripts.put("single_input", vars -> {
122+
double value = (double) vars.get("_value");
123+
return value;
124+
});
125+
120126
scripts.put("return null", vars -> null);
121127

122128
return scripts;
@@ -628,4 +634,159 @@ public void testPartiallyUnmapped() throws Exception {
628634
}
629635
}
630636
}
637+
638+
public void testSingleBucketPathAgg() throws Exception {
639+
XContentBuilder content = XContentFactory.jsonBuilder()
640+
.startObject()
641+
.field("buckets_path", "field2Sum")
642+
.startObject("script")
643+
.field("source", "single_input")
644+
.field("lang", CustomScriptPlugin.NAME)
645+
.endObject()
646+
.endObject();
647+
BucketScriptPipelineAggregationBuilder bucketScriptAgg =
648+
BucketScriptPipelineAggregationBuilder.parse("seriesArithmetic", createParser(content));
649+
650+
SearchResponse response = client()
651+
.prepareSearch("idx", "idx_unmapped")
652+
.addAggregation(
653+
histogram("histo")
654+
.field(FIELD_1_NAME)
655+
.interval(interval)
656+
.subAggregation(sum("field2Sum").field(FIELD_2_NAME))
657+
.subAggregation(bucketScriptAgg)).get();
658+
659+
assertSearchResponse(response);
660+
661+
Histogram histo = response.getAggregations().get("histo");
662+
assertThat(histo, notNullValue());
663+
assertThat(histo.getName(), equalTo("histo"));
664+
List<? extends Histogram.Bucket> buckets = histo.getBuckets();
665+
666+
for (int i = 0; i < buckets.size(); ++i) {
667+
Histogram.Bucket bucket = buckets.get(i);
668+
if (bucket.getDocCount() == 0) {
669+
SimpleValue seriesArithmetic = bucket.getAggregations().get("seriesArithmetic");
670+
assertThat(seriesArithmetic, nullValue());
671+
} else {
672+
Sum field2Sum = bucket.getAggregations().get("field2Sum");
673+
assertThat(field2Sum, notNullValue());
674+
double field2SumValue = field2Sum.getValue();
675+
SimpleValue seriesArithmetic = bucket.getAggregations().get("seriesArithmetic");
676+
assertThat(seriesArithmetic, notNullValue());
677+
double seriesArithmeticValue = seriesArithmetic.value();
678+
assertThat(seriesArithmeticValue, equalTo(field2SumValue));
679+
}
680+
}
681+
}
682+
683+
public void testArrayBucketPathAgg() throws Exception {
684+
XContentBuilder content = XContentFactory.jsonBuilder()
685+
.startObject()
686+
.array("buckets_path", "field2Sum", "field3Sum", "field4Sum")
687+
.startObject("script")
688+
.field("source", "_value0 + _value1 + _value2")
689+
.field("lang", CustomScriptPlugin.NAME)
690+
.endObject()
691+
.endObject();
692+
BucketScriptPipelineAggregationBuilder bucketScriptAgg =
693+
BucketScriptPipelineAggregationBuilder.parse("seriesArithmetic", createParser(content));
694+
695+
SearchResponse response = client()
696+
.prepareSearch("idx", "idx_unmapped")
697+
.addAggregation(
698+
histogram("histo")
699+
.field(FIELD_1_NAME)
700+
.interval(interval)
701+
.subAggregation(sum("field2Sum").field(FIELD_2_NAME))
702+
.subAggregation(sum("field3Sum").field(FIELD_3_NAME))
703+
.subAggregation(sum("field4Sum").field(FIELD_4_NAME))
704+
.subAggregation(bucketScriptAgg)).get();
705+
706+
assertSearchResponse(response);
707+
708+
Histogram histo = response.getAggregations().get("histo");
709+
assertThat(histo, notNullValue());
710+
assertThat(histo.getName(), equalTo("histo"));
711+
List<? extends Histogram.Bucket> buckets = histo.getBuckets();
712+
713+
for (int i = 0; i < buckets.size(); ++i) {
714+
Histogram.Bucket bucket = buckets.get(i);
715+
if (bucket.getDocCount() == 0) {
716+
SimpleValue seriesArithmetic = bucket.getAggregations().get("seriesArithmetic");
717+
assertThat(seriesArithmetic, nullValue());
718+
} else {
719+
Sum field2Sum = bucket.getAggregations().get("field2Sum");
720+
assertThat(field2Sum, notNullValue());
721+
double field2SumValue = field2Sum.getValue();
722+
Sum field3Sum = bucket.getAggregations().get("field3Sum");
723+
assertThat(field3Sum, notNullValue());
724+
double field3SumValue = field3Sum.getValue();
725+
Sum field4Sum = bucket.getAggregations().get("field4Sum");
726+
assertThat(field4Sum, notNullValue());
727+
double field4SumValue = field4Sum.getValue();
728+
SimpleValue seriesArithmetic = bucket.getAggregations().get("seriesArithmetic");
729+
assertThat(seriesArithmetic, notNullValue());
730+
double seriesArithmeticValue = seriesArithmetic.value();
731+
assertThat(seriesArithmeticValue, equalTo(field2SumValue + field3SumValue + field4SumValue));
732+
}
733+
}
734+
}
735+
736+
public void testObjectBucketPathAgg() throws Exception {
737+
XContentBuilder content = XContentFactory.jsonBuilder()
738+
.startObject()
739+
.startObject("buckets_path")
740+
.field("_value0", "field2Sum")
741+
.field("_value1", "field3Sum")
742+
.field("_value2", "field4Sum")
743+
.endObject()
744+
.startObject("script")
745+
.field("source", "_value0 + _value1 + _value2")
746+
.field("lang", CustomScriptPlugin.NAME)
747+
.endObject()
748+
.endObject();
749+
BucketScriptPipelineAggregationBuilder bucketScriptAgg =
750+
BucketScriptPipelineAggregationBuilder.parse("seriesArithmetic", createParser(content));
751+
752+
SearchResponse response = client()
753+
.prepareSearch("idx", "idx_unmapped")
754+
.addAggregation(
755+
histogram("histo")
756+
.field(FIELD_1_NAME)
757+
.interval(interval)
758+
.subAggregation(sum("field2Sum").field(FIELD_2_NAME))
759+
.subAggregation(sum("field3Sum").field(FIELD_3_NAME))
760+
.subAggregation(sum("field4Sum").field(FIELD_4_NAME))
761+
.subAggregation(bucketScriptAgg)).get();
762+
763+
assertSearchResponse(response);
764+
765+
Histogram histo = response.getAggregations().get("histo");
766+
assertThat(histo, notNullValue());
767+
assertThat(histo.getName(), equalTo("histo"));
768+
List<? extends Histogram.Bucket> buckets = histo.getBuckets();
769+
770+
for (int i = 0; i < buckets.size(); ++i) {
771+
Histogram.Bucket bucket = buckets.get(i);
772+
if (bucket.getDocCount() == 0) {
773+
SimpleValue seriesArithmetic = bucket.getAggregations().get("seriesArithmetic");
774+
assertThat(seriesArithmetic, nullValue());
775+
} else {
776+
Sum field2Sum = bucket.getAggregations().get("field2Sum");
777+
assertThat(field2Sum, notNullValue());
778+
double field2SumValue = field2Sum.getValue();
779+
Sum field3Sum = bucket.getAggregations().get("field3Sum");
780+
assertThat(field3Sum, notNullValue());
781+
double field3SumValue = field3Sum.getValue();
782+
Sum field4Sum = bucket.getAggregations().get("field4Sum");
783+
assertThat(field4Sum, notNullValue());
784+
double field4SumValue = field4Sum.getValue();
785+
SimpleValue seriesArithmetic = bucket.getAggregations().get("seriesArithmetic");
786+
assertThat(seriesArithmetic, notNullValue());
787+
double seriesArithmeticValue = seriesArithmetic.value();
788+
assertThat(seriesArithmeticValue, equalTo(field2SumValue + field3SumValue + field4SumValue));
789+
}
790+
}
791+
}
631792
}

server/src/test/java/org/elasticsearch/search/aggregations/pipeline/BucketScriptTests.java

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,11 +19,14 @@
1919

2020
package org.elasticsearch.search.aggregations.pipeline;
2121

22+
import org.elasticsearch.common.xcontent.XContentBuilder;
23+
import org.elasticsearch.common.xcontent.XContentFactory;
2224
import org.elasticsearch.script.Script;
2325
import org.elasticsearch.script.ScriptType;
2426
import org.elasticsearch.search.aggregations.BasePipelineAggregationTestCase;
2527
import org.elasticsearch.search.aggregations.pipeline.BucketHelpers.GapPolicy;
2628

29+
import java.io.IOException;
2730
import java.util.HashMap;
2831
import java.util.Map;
2932

@@ -59,4 +62,47 @@ protected BucketScriptPipelineAggregationBuilder createTestAggregatorFactory() {
5962
return factory;
6063
}
6164

65+
public void testParseBucketPath() throws IOException {
66+
XContentBuilder content = XContentFactory.jsonBuilder()
67+
.startObject()
68+
.field("buckets_path", "_count")
69+
.startObject("script")
70+
.field("source", "value")
71+
.field("lang", "expression")
72+
.endObject()
73+
.endObject();
74+
BucketScriptPipelineAggregationBuilder builder1 = BucketScriptPipelineAggregationBuilder.parse("count", createParser(content));
75+
assertEquals(builder1.getBucketsPaths().length , 1);
76+
assertEquals(builder1.getBucketsPaths()[0], "_count");
77+
78+
content = XContentFactory.jsonBuilder()
79+
.startObject()
80+
.startObject("buckets_path")
81+
.field("path1", "_count1")
82+
.field("path2", "_count2")
83+
.endObject()
84+
.startObject("script")
85+
.field("source", "value")
86+
.field("lang", "expression")
87+
.endObject()
88+
.endObject();
89+
BucketScriptPipelineAggregationBuilder builder2 = BucketScriptPipelineAggregationBuilder.parse("count", createParser(content));
90+
assertEquals(builder2.getBucketsPaths().length , 2);
91+
assertEquals(builder2.getBucketsPaths()[0], "_count1");
92+
assertEquals(builder2.getBucketsPaths()[1], "_count2");
93+
94+
content = XContentFactory.jsonBuilder()
95+
.startObject()
96+
.array("buckets_path","_count1", "_count2")
97+
.startObject("script")
98+
.field("source", "value")
99+
.field("lang", "expression")
100+
.endObject()
101+
.endObject();
102+
BucketScriptPipelineAggregationBuilder builder3 = BucketScriptPipelineAggregationBuilder.parse("count", createParser(content));
103+
assertEquals(builder3.getBucketsPaths().length , 2);
104+
assertEquals(builder3.getBucketsPaths()[0], "_count1");
105+
assertEquals(builder3.getBucketsPaths()[1], "_count2");
106+
}
107+
62108
}

0 commit comments

Comments
 (0)