Skip to content

Commit c5443f7

Browse files
authored
Add Inference Pipeline aggregation to HLRC (#59086) (#59250)
Adds InferencePipelineAggregationBuilder to the HLRC duplicating the server side classes
1 parent d56fc72 commit c5443f7

File tree

12 files changed

+747
-1
lines changed

12 files changed

+747
-1
lines changed

client/rest-high-level/src/main/java/org/elasticsearch/client/RestHighLevelClient.java

+3
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,8 @@
5454
import org.elasticsearch.action.support.master.AcknowledgedResponse;
5555
import org.elasticsearch.action.update.UpdateRequest;
5656
import org.elasticsearch.action.update.UpdateResponse;
57+
import org.elasticsearch.client.analytics.InferencePipelineAggregationBuilder;
58+
import org.elasticsearch.client.analytics.ParsedInference;
5759
import org.elasticsearch.client.analytics.ParsedStringStats;
5860
import org.elasticsearch.client.analytics.ParsedTopMetrics;
5961
import org.elasticsearch.client.analytics.StringStatsAggregationBuilder;
@@ -1957,6 +1959,7 @@ static List<NamedXContentRegistry.Entry> getDefaultNamedXContents() {
19571959
map.put(CompositeAggregationBuilder.NAME, (p, c) -> ParsedComposite.fromXContent(p, (String) c));
19581960
map.put(StringStatsAggregationBuilder.NAME, (p, c) -> ParsedStringStats.PARSER.parse(p, (String) c));
19591961
map.put(TopMetricsAggregationBuilder.NAME, (p, c) -> ParsedTopMetrics.PARSER.parse(p, (String) c));
1962+
map.put(InferencePipelineAggregationBuilder.NAME, (p, c) -> ParsedInference.fromXContent(p, (String ) (c)));
19601963
List<NamedXContentRegistry.Entry> entries = map.entrySet().stream()
19611964
.map(entry -> new NamedXContentRegistry.Entry(Aggregation.class, new ParseField(entry.getKey()), entry.getValue()))
19621965
.collect(Collectors.toList());
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,141 @@
1+
/*
2+
* Licensed to Elasticsearch under one or more contributor
3+
* license agreements. See the NOTICE file distributed with
4+
* this work for additional information regarding copyright
5+
* ownership. Elasticsearch licenses this file to you under
6+
* the Apache License, Version 2.0 (the "License"); you may
7+
* not use this file except in compliance with the License.
8+
* You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing,
13+
* software distributed under the License is distributed on an
14+
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15+
* KIND, either express or implied. See the License for the
16+
* specific language governing permissions and limitations
17+
* under the License.
18+
*/
19+
20+
package org.elasticsearch.client.analytics;
21+
22+
import org.elasticsearch.client.ml.inference.trainedmodel.InferenceConfig;
23+
import org.elasticsearch.common.ParseField;
24+
import org.elasticsearch.common.io.stream.StreamOutput;
25+
import org.elasticsearch.common.xcontent.ConstructingObjectParser;
26+
import org.elasticsearch.common.xcontent.XContentBuilder;
27+
import org.elasticsearch.common.xcontent.XContentParser;
28+
import org.elasticsearch.search.aggregations.PipelineAggregationBuilder;
29+
import org.elasticsearch.search.aggregations.pipeline.AbstractPipelineAggregationBuilder;
30+
import org.elasticsearch.search.aggregations.pipeline.PipelineAggregator;
31+
import org.elasticsearch.search.builder.SearchSourceBuilder;
32+
33+
import java.io.IOException;
34+
import java.util.Map;
35+
import java.util.Objects;
36+
import java.util.TreeMap;
37+
38+
import static org.elasticsearch.common.xcontent.ConstructingObjectParser.constructorArg;
39+
40+
/**
41+
* For building inference pipeline aggregations
42+
*
43+
* NOTE: This extends {@linkplain AbstractPipelineAggregationBuilder} for compatibility
44+
* with {@link SearchSourceBuilder#aggregation(PipelineAggregationBuilder)} but it
45+
* doesn't support any "server" side things like {@linkplain #doWriteTo(StreamOutput)}
46+
* or {@linkplain #createInternal(Map)}
47+
*/
48+
public class InferencePipelineAggregationBuilder extends AbstractPipelineAggregationBuilder<InferencePipelineAggregationBuilder> {
49+
50+
public static String NAME = "inference";
51+
52+
public static final ParseField MODEL_ID = new ParseField("model_id");
53+
private static final ParseField INFERENCE_CONFIG = new ParseField("inference_config");
54+
55+
56+
@SuppressWarnings("unchecked")
57+
private static final ConstructingObjectParser<InferencePipelineAggregationBuilder, String> PARSER = new ConstructingObjectParser<>(
58+
NAME, false,
59+
(args, name) -> new InferencePipelineAggregationBuilder(name, (String)args[0], (Map<String, String>) args[1])
60+
);
61+
62+
static {
63+
PARSER.declareString(constructorArg(), MODEL_ID);
64+
PARSER.declareObject(constructorArg(), (p, c) -> p.mapStrings(), BUCKETS_PATH_FIELD);
65+
PARSER.declareNamedObject(InferencePipelineAggregationBuilder::setInferenceConfig,
66+
(p, c, n) -> p.namedObject(InferenceConfig.class, n, c), INFERENCE_CONFIG);
67+
}
68+
69+
private final Map<String, String> bucketPathMap;
70+
private final String modelId;
71+
private InferenceConfig inferenceConfig;
72+
73+
public static InferencePipelineAggregationBuilder parse(String pipelineAggregatorName,
74+
XContentParser parser) {
75+
return PARSER.apply(parser, pipelineAggregatorName);
76+
}
77+
78+
public InferencePipelineAggregationBuilder(String name, String modelId, Map<String, String> bucketsPath) {
79+
super(name, NAME, new TreeMap<>(bucketsPath).values().toArray(new String[] {}));
80+
this.modelId = modelId;
81+
this.bucketPathMap = bucketsPath;
82+
}
83+
84+
public void setInferenceConfig(InferenceConfig inferenceConfig) {
85+
this.inferenceConfig = inferenceConfig;
86+
}
87+
88+
@Override
89+
protected void validate(ValidationContext context) {
90+
// validation occurs on the server
91+
}
92+
93+
@Override
94+
protected void doWriteTo(StreamOutput out) {
95+
throw new UnsupportedOperationException();
96+
}
97+
98+
@Override
99+
protected PipelineAggregator createInternal(Map<String, Object> metaData) {
100+
throw new UnsupportedOperationException();
101+
}
102+
103+
@Override
104+
protected boolean overrideBucketsPath() {
105+
return true;
106+
}
107+
108+
@Override
109+
protected XContentBuilder internalXContent(XContentBuilder builder, Params params) throws IOException {
110+
builder.field(MODEL_ID.getPreferredName(), modelId);
111+
builder.field(BUCKETS_PATH_FIELD.getPreferredName(), bucketPathMap);
112+
if (inferenceConfig != null) {
113+
builder.startObject(INFERENCE_CONFIG.getPreferredName());
114+
builder.field(inferenceConfig.getName(), inferenceConfig);
115+
builder.endObject();
116+
}
117+
return builder;
118+
}
119+
120+
@Override
121+
public String getWriteableName() {
122+
return NAME;
123+
}
124+
125+
@Override
126+
public int hashCode() {
127+
return Objects.hash(super.hashCode(), bucketPathMap, modelId, inferenceConfig);
128+
}
129+
130+
@Override
131+
public boolean equals(Object obj) {
132+
if (this == obj) return true;
133+
if (obj == null || getClass() != obj.getClass()) return false;
134+
if (super.equals(obj) == false) return false;
135+
136+
InferencePipelineAggregationBuilder other = (InferencePipelineAggregationBuilder) obj;
137+
return Objects.equals(bucketPathMap, other.bucketPathMap)
138+
&& Objects.equals(modelId, other.modelId)
139+
&& Objects.equals(inferenceConfig, other.inferenceConfig);
140+
}
141+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,137 @@
1+
/*
2+
* Licensed to Elasticsearch under one or more contributor
3+
* license agreements. See the NOTICE file distributed with
4+
* this work for additional information regarding copyright
5+
* ownership. Elasticsearch licenses this file to you under
6+
* the Apache License, Version 2.0 (the "License"); you may
7+
* not use this file except in compliance with the License.
8+
* You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing,
13+
* software distributed under the License is distributed on an
14+
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15+
* KIND, either express or implied. See the License for the
16+
* specific language governing permissions and limitations
17+
* under the License.
18+
*/
19+
20+
package org.elasticsearch.client.analytics;
21+
22+
import org.elasticsearch.client.ml.inference.results.FeatureImportance;
23+
import org.elasticsearch.client.ml.inference.results.TopClassEntry;
24+
import org.elasticsearch.common.ParseField;
25+
import org.elasticsearch.common.xcontent.ConstructingObjectParser;
26+
import org.elasticsearch.common.xcontent.ObjectParser;
27+
import org.elasticsearch.common.xcontent.XContentBuilder;
28+
import org.elasticsearch.common.xcontent.XContentParseException;
29+
import org.elasticsearch.common.xcontent.XContentParser;
30+
import org.elasticsearch.search.aggregations.ParsedAggregation;
31+
32+
import java.io.IOException;
33+
import java.util.List;
34+
35+
/**
36+
* This class parses the superset of all possible fields that may be written by
37+
* InferenceResults. The warning field is mutually exclusive with all the other fields.
38+
*
39+
* In the case of classification results {@link #getValue()} may return a String,
40+
* Boolean or a Double. For regression results {@link #getValue()} is always
41+
* a Double.
42+
*/
43+
import static org.elasticsearch.common.xcontent.ConstructingObjectParser.optionalConstructorArg;
44+
45+
public class ParsedInference extends ParsedAggregation {
46+
47+
@SuppressWarnings("unchecked")
48+
private static final ConstructingObjectParser<ParsedInference, Void> PARSER =
49+
new ConstructingObjectParser<>(ParsedInference.class.getSimpleName(), true,
50+
args -> new ParsedInference(args[0], (List<FeatureImportance>) args[1],
51+
(List<TopClassEntry>) args[2], (String) args[3]));
52+
53+
public static final ParseField FEATURE_IMPORTANCE = new ParseField("feature_importance");
54+
public static final ParseField WARNING = new ParseField("warning");
55+
public static final ParseField TOP_CLASSES = new ParseField("top_classes");
56+
57+
static {
58+
PARSER.declareField(optionalConstructorArg(), (p, n) -> {
59+
Object o;
60+
XContentParser.Token token = p.currentToken();
61+
if (token == XContentParser.Token.VALUE_STRING) {
62+
o = p.text();
63+
} else if (token == XContentParser.Token.VALUE_BOOLEAN) {
64+
o = p.booleanValue();
65+
} else if (token == XContentParser.Token.VALUE_NUMBER) {
66+
o = p.doubleValue();
67+
} else {
68+
throw new XContentParseException(p.getTokenLocation(),
69+
"[" + ParsedInference.class.getSimpleName() + "] failed to parse field [" + CommonFields.VALUE + "] "
70+
+ "value [" + token + "] is not a string, boolean or number");
71+
}
72+
return o;
73+
}, CommonFields.VALUE, ObjectParser.ValueType.VALUE);
74+
PARSER.declareObjectArray(optionalConstructorArg(), (p, c) -> FeatureImportance.fromXContent(p), FEATURE_IMPORTANCE);
75+
PARSER.declareObjectArray(optionalConstructorArg(), (p, c) -> TopClassEntry.fromXContent(p), TOP_CLASSES);
76+
PARSER.declareString(optionalConstructorArg(), WARNING);
77+
declareAggregationFields(PARSER);
78+
}
79+
80+
public static ParsedInference fromXContent(XContentParser parser, final String name) {
81+
ParsedInference parsed = PARSER.apply(parser, null);
82+
parsed.setName(name);
83+
return parsed;
84+
}
85+
86+
private final Object value;
87+
private final List<FeatureImportance> featureImportance;
88+
private final List<TopClassEntry> topClasses;
89+
private final String warning;
90+
91+
ParsedInference(Object value,
92+
List<FeatureImportance> featureImportance,
93+
List<TopClassEntry> topClasses,
94+
String warning) {
95+
this.value = value;
96+
this.warning = warning;
97+
this.featureImportance = featureImportance;
98+
this.topClasses = topClasses;
99+
}
100+
101+
public Object getValue() {
102+
return value;
103+
}
104+
105+
public List<FeatureImportance> getFeatureImportance() {
106+
return featureImportance;
107+
}
108+
109+
public List<TopClassEntry> getTopClasses() {
110+
return topClasses;
111+
}
112+
113+
public String getWarning() {
114+
return warning;
115+
}
116+
117+
@Override
118+
protected XContentBuilder doXContentBody(XContentBuilder builder, Params params) throws IOException {
119+
if (warning != null) {
120+
builder.field(WARNING.getPreferredName(), warning);
121+
} else {
122+
builder.field(CommonFields.VALUE.getPreferredName(), value);
123+
if (topClasses != null && topClasses.size() > 0) {
124+
builder.field(TOP_CLASSES.getPreferredName(), topClasses);
125+
}
126+
if (featureImportance != null && featureImportance.size() > 0) {
127+
builder.field(FEATURE_IMPORTANCE.getPreferredName(), featureImportance);
128+
}
129+
}
130+
return builder;
131+
}
132+
133+
@Override
134+
public String getType() {
135+
return InferencePipelineAggregationBuilder.NAME;
136+
}
137+
}

0 commit comments

Comments
 (0)