Skip to content

Commit b551f75

Browse files
authored
[ML] add new custom field to trained model processors (#59542)
This commit adds the new configurable field `custom`. `custom` indicates if the preprocessor was submitted by a user or automatically created by the analytics job. Eventually, this field will be used in calculating feature importance. When `custom` is true, the feature importance for the processed fields is calculated. When `false` the current behavior is the same (we calculate the importance for the originating field/feature). This also adds new required methods to the preprocessor interface. If users are to supply their own preprocessors in the analytics job configuration, we need to know the input and output field names.
1 parent 647a413 commit b551f75

File tree

18 files changed

+433
-185
lines changed

18 files changed

+433
-185
lines changed

client/rest-high-level/src/main/java/org/elasticsearch/client/ml/inference/preprocessing/FrequencyEncoding.java

Lines changed: 22 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -40,18 +40,20 @@ public class FrequencyEncoding implements PreProcessor {
4040
public static final ParseField FIELD = new ParseField("field");
4141
public static final ParseField FEATURE_NAME = new ParseField("feature_name");
4242
public static final ParseField FREQUENCY_MAP = new ParseField("frequency_map");
43+
public static final ParseField CUSTOM = new ParseField("custom");
4344

4445
@SuppressWarnings("unchecked")
4546
public static final ConstructingObjectParser<FrequencyEncoding, Void> PARSER = new ConstructingObjectParser<>(
4647
NAME,
4748
true,
48-
a -> new FrequencyEncoding((String)a[0], (String)a[1], (Map<String, Double>)a[2]));
49+
a -> new FrequencyEncoding((String)a[0], (String)a[1], (Map<String, Double>)a[2], (Boolean)a[3]));
4950
static {
5051
PARSER.declareString(ConstructingObjectParser.constructorArg(), FIELD);
5152
PARSER.declareString(ConstructingObjectParser.constructorArg(), FEATURE_NAME);
5253
PARSER.declareObject(ConstructingObjectParser.constructorArg(),
5354
(p, c) -> p.map(HashMap::new, XContentParser::doubleValue),
5455
FREQUENCY_MAP);
56+
PARSER.declareBoolean(ConstructingObjectParser.optionalConstructorArg(), CUSTOM);
5557
}
5658

5759
public static FrequencyEncoding fromXContent(XContentParser parser) {
@@ -61,11 +63,13 @@ public static FrequencyEncoding fromXContent(XContentParser parser) {
6163
private final String field;
6264
private final String featureName;
6365
private final Map<String, Double> frequencyMap;
66+
private final Boolean custom;
6467

65-
public FrequencyEncoding(String field, String featureName, Map<String, Double> frequencyMap) {
68+
FrequencyEncoding(String field, String featureName, Map<String, Double> frequencyMap, Boolean custom) {
6669
this.field = Objects.requireNonNull(field);
6770
this.featureName = Objects.requireNonNull(featureName);
6871
this.frequencyMap = Collections.unmodifiableMap(Objects.requireNonNull(frequencyMap));
72+
this.custom = custom;
6973
}
7074

7175
/**
@@ -94,12 +98,19 @@ public String getName() {
9498
return NAME;
9599
}
96100

101+
public Boolean getCustom() {
102+
return custom;
103+
}
104+
97105
@Override
98106
public XContentBuilder toXContent(XContentBuilder builder, ToXContent.Params params) throws IOException {
99107
builder.startObject();
100108
builder.field(FIELD.getPreferredName(), field);
101109
builder.field(FEATURE_NAME.getPreferredName(), featureName);
102110
builder.field(FREQUENCY_MAP.getPreferredName(), frequencyMap);
111+
if (custom != null) {
112+
builder.field(CUSTOM.getPreferredName(), custom);
113+
}
103114
builder.endObject();
104115
return builder;
105116
}
@@ -111,12 +122,13 @@ public boolean equals(Object o) {
111122
FrequencyEncoding that = (FrequencyEncoding) o;
112123
return Objects.equals(field, that.field)
113124
&& Objects.equals(featureName, that.featureName)
125+
&& Objects.equals(custom, that.custom)
114126
&& Objects.equals(frequencyMap, that.frequencyMap);
115127
}
116128

117129
@Override
118130
public int hashCode() {
119-
return Objects.hash(field, featureName, frequencyMap);
131+
return Objects.hash(field, featureName, frequencyMap, custom);
120132
}
121133

122134
public Builder builder(String field) {
@@ -128,6 +140,7 @@ public static class Builder {
128140
private String field;
129141
private String featureName;
130142
private Map<String, Double> frequencyMap = new HashMap<>();
143+
private Boolean custom;
131144

132145
public Builder(String field) {
133146
this.field = field;
@@ -153,8 +166,13 @@ public Builder addFrequency(String valueName, double frequency) {
153166
return this;
154167
}
155168

169+
public Builder setCustom(boolean custom) {
170+
this.custom = custom;
171+
return this;
172+
}
173+
156174
public FrequencyEncoding build() {
157-
return new FrequencyEncoding(field, featureName, frequencyMap);
175+
return new FrequencyEncoding(field, featureName, frequencyMap, custom);
158176
}
159177
}
160178

client/rest-high-level/src/main/java/org/elasticsearch/client/ml/inference/preprocessing/OneHotEncoding.java

Lines changed: 23 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -38,15 +38,17 @@ public class OneHotEncoding implements PreProcessor {
3838
public static final String NAME = "one_hot_encoding";
3939
public static final ParseField FIELD = new ParseField("field");
4040
public static final ParseField HOT_MAP = new ParseField("hot_map");
41+
public static final ParseField CUSTOM = new ParseField("custom");
4142

4243
@SuppressWarnings("unchecked")
4344
public static final ConstructingObjectParser<OneHotEncoding, Void> PARSER = new ConstructingObjectParser<>(
4445
NAME,
4546
true,
46-
a -> new OneHotEncoding((String)a[0], (Map<String, String>)a[1]));
47+
a -> new OneHotEncoding((String)a[0], (Map<String, String>)a[1], (Boolean)a[2]));
4748
static {
4849
PARSER.declareString(ConstructingObjectParser.constructorArg(), FIELD);
4950
PARSER.declareObject(ConstructingObjectParser.constructorArg(), (p, c) -> p.mapStrings(), HOT_MAP);
51+
PARSER.declareBoolean(ConstructingObjectParser.optionalConstructorArg(), CUSTOM);
5052
}
5153

5254
public static OneHotEncoding fromXContent(XContentParser parser) {
@@ -55,12 +57,13 @@ public static OneHotEncoding fromXContent(XContentParser parser) {
5557

5658
private final String field;
5759
private final Map<String, String> hotMap;
60+
private final Boolean custom;
5861

59-
public OneHotEncoding(String field, Map<String, String> hotMap) {
62+
OneHotEncoding(String field, Map<String, String> hotMap, Boolean custom) {
6063
this.field = Objects.requireNonNull(field);
6164
this.hotMap = Collections.unmodifiableMap(Objects.requireNonNull(hotMap));
65+
this.custom = custom;
6266
}
63-
6467
/**
6568
* @return Field name on which to one hot encode
6669
*/
@@ -80,11 +83,18 @@ public String getName() {
8083
return NAME;
8184
}
8285

86+
public Boolean getCustom() {
87+
return custom;
88+
}
89+
8390
@Override
8491
public XContentBuilder toXContent(XContentBuilder builder, ToXContent.Params params) throws IOException {
8592
builder.startObject();
8693
builder.field(FIELD.getPreferredName(), field);
8794
builder.field(HOT_MAP.getPreferredName(), hotMap);
95+
if (custom != null) {
96+
builder.field(CUSTOM.getPreferredName(), custom);
97+
}
8898
builder.endObject();
8999
return builder;
90100
}
@@ -95,12 +105,13 @@ public boolean equals(Object o) {
95105
if (o == null || getClass() != o.getClass()) return false;
96106
OneHotEncoding that = (OneHotEncoding) o;
97107
return Objects.equals(field, that.field)
98-
&& Objects.equals(hotMap, that.hotMap);
108+
&& Objects.equals(hotMap, that.hotMap)
109+
&& Objects.equals(custom, that.custom);
99110
}
100111

101112
@Override
102113
public int hashCode() {
103-
return Objects.hash(field, hotMap);
114+
return Objects.hash(field, hotMap, custom);
104115
}
105116

106117
public Builder builder(String field) {
@@ -111,6 +122,7 @@ public static class Builder {
111122

112123
private String field;
113124
private Map<String, String> hotMap = new HashMap<>();
125+
private Boolean custom;
114126

115127
public Builder(String field) {
116128
this.field = field;
@@ -131,8 +143,13 @@ public Builder addOneHot(String valueName, String oneHotFeatureName) {
131143
return this;
132144
}
133145

146+
public Builder setCustom(boolean custom) {
147+
this.custom = custom;
148+
return this;
149+
}
150+
134151
public OneHotEncoding build() {
135-
return new OneHotEncoding(field, hotMap);
152+
return new OneHotEncoding(field, hotMap, custom);
136153
}
137154
}
138155
}

client/rest-high-level/src/main/java/org/elasticsearch/client/ml/inference/preprocessing/TargetMeanEncoding.java

Lines changed: 23 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -41,19 +41,21 @@ public class TargetMeanEncoding implements PreProcessor {
4141
public static final ParseField FEATURE_NAME = new ParseField("feature_name");
4242
public static final ParseField TARGET_MAP = new ParseField("target_map");
4343
public static final ParseField DEFAULT_VALUE = new ParseField("default_value");
44+
public static final ParseField CUSTOM = new ParseField("custom");
4445

4546
@SuppressWarnings("unchecked")
4647
public static final ConstructingObjectParser<TargetMeanEncoding, Void> PARSER = new ConstructingObjectParser<>(
4748
NAME,
4849
true,
49-
a -> new TargetMeanEncoding((String)a[0], (String)a[1], (Map<String, Double>)a[2], (Double)a[3]));
50+
a -> new TargetMeanEncoding((String)a[0], (String)a[1], (Map<String, Double>)a[2], (Double)a[3], (Boolean)a[4]));
5051
static {
5152
PARSER.declareString(ConstructingObjectParser.constructorArg(), FIELD);
5253
PARSER.declareString(ConstructingObjectParser.constructorArg(), FEATURE_NAME);
5354
PARSER.declareObject(ConstructingObjectParser.constructorArg(),
5455
(p, c) -> p.map(HashMap::new, XContentParser::doubleValue),
5556
TARGET_MAP);
5657
PARSER.declareDouble(ConstructingObjectParser.constructorArg(), DEFAULT_VALUE);
58+
PARSER.declareBoolean(ConstructingObjectParser.optionalConstructorArg(), CUSTOM);
5759
}
5860

5961
public static TargetMeanEncoding fromXContent(XContentParser parser) {
@@ -64,12 +66,14 @@ public static TargetMeanEncoding fromXContent(XContentParser parser) {
6466
private final String featureName;
6567
private final Map<String, Double> meanMap;
6668
private final double defaultValue;
69+
private final Boolean custom;
6770

68-
public TargetMeanEncoding(String field, String featureName, Map<String, Double> meanMap, Double defaultValue) {
71+
TargetMeanEncoding(String field, String featureName, Map<String, Double> meanMap, Double defaultValue, Boolean custom) {
6972
this.field = Objects.requireNonNull(field);
7073
this.featureName = Objects.requireNonNull(featureName);
7174
this.meanMap = Collections.unmodifiableMap(Objects.requireNonNull(meanMap));
7275
this.defaultValue = Objects.requireNonNull(defaultValue);
76+
this.custom = custom;
7377
}
7478

7579
/**
@@ -100,6 +104,10 @@ public String getFeatureName() {
100104
return featureName;
101105
}
102106

107+
public Boolean getCustom() {
108+
return custom;
109+
}
110+
103111
@Override
104112
public String getName() {
105113
return NAME;
@@ -112,6 +120,9 @@ public XContentBuilder toXContent(XContentBuilder builder, ToXContent.Params par
112120
builder.field(FEATURE_NAME.getPreferredName(), featureName);
113121
builder.field(TARGET_MAP.getPreferredName(), meanMap);
114122
builder.field(DEFAULT_VALUE.getPreferredName(), defaultValue);
123+
if (custom != null) {
124+
builder.field(CUSTOM.getPreferredName(), custom);
125+
}
115126
builder.endObject();
116127
return builder;
117128
}
@@ -124,12 +135,13 @@ public boolean equals(Object o) {
124135
return Objects.equals(field, that.field)
125136
&& Objects.equals(featureName, that.featureName)
126137
&& Objects.equals(meanMap, that.meanMap)
127-
&& Objects.equals(defaultValue, that.defaultValue);
138+
&& Objects.equals(defaultValue, that.defaultValue)
139+
&& Objects.equals(custom, that.custom);
128140
}
129141

130142
@Override
131143
public int hashCode() {
132-
return Objects.hash(field, featureName, meanMap, defaultValue);
144+
return Objects.hash(field, featureName, meanMap, defaultValue, custom);
133145
}
134146

135147
public Builder builder(String field) {
@@ -142,6 +154,7 @@ public static class Builder {
142154
private String featureName;
143155
private Map<String, Double> meanMap = new HashMap<>();
144156
private double defaultValue;
157+
private Boolean custom;
145158

146159
public Builder(String field) {
147160
this.field = field;
@@ -176,8 +189,13 @@ public Builder setDefaultValue(double defaultValue) {
176189
return this;
177190
}
178191

192+
public Builder setCustom(boolean custom) {
193+
this.custom = custom;
194+
return this;
195+
}
196+
179197
public TargetMeanEncoding build() {
180-
return new TargetMeanEncoding(field, featureName, meanMap, defaultValue);
198+
return new TargetMeanEncoding(field, featureName, meanMap, defaultValue, custom);
181199
}
182200
}
183201
}

client/rest-high-level/src/test/java/org/elasticsearch/client/ml/inference/preprocessing/FrequencyEncodingTests.java

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,9 @@ public static FrequencyEncoding createRandom() {
5555
for (int i = 0; i < valuesSize; i++) {
5656
valueMap.put(randomAlphaOfLength(10), randomDoubleBetween(0.0, 1.0, false));
5757
}
58-
return new FrequencyEncoding(randomAlphaOfLength(10), randomAlphaOfLength(10), valueMap);
58+
return new FrequencyEncoding(randomAlphaOfLength(10),
59+
randomAlphaOfLength(10),
60+
valueMap,
61+
randomBoolean() ? null : randomBoolean());
5962
}
6063
}

client/rest-high-level/src/test/java/org/elasticsearch/client/ml/inference/preprocessing/OneHotEncodingTests.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@ public static OneHotEncoding createRandom() {
5555
for (int i = 0; i < valuesSize; i++) {
5656
valueMap.put(randomAlphaOfLength(10), randomAlphaOfLength(10));
5757
}
58-
return new OneHotEncoding(randomAlphaOfLength(10), valueMap);
58+
return new OneHotEncoding(randomAlphaOfLength(10), valueMap, randomBoolean() ? null : randomBoolean());
5959
}
6060

6161
}

client/rest-high-level/src/test/java/org/elasticsearch/client/ml/inference/preprocessing/TargetMeanEncodingTests.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,8 @@ public static TargetMeanEncoding createRandom() {
5858
return new TargetMeanEncoding(randomAlphaOfLength(10),
5959
randomAlphaOfLength(10),
6060
valueMap,
61-
randomDoubleBetween(0.0, 1.0, false));
61+
randomDoubleBetween(0.0, 1.0, false),
62+
randomBoolean() ? null : randomBoolean());
6263
}
6364

6465
}

docs/reference/ml/df-analytics/apis/put-inference.asciidoc

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -94,6 +94,10 @@ The field name to encode.
9494
`frequency_map`::
9595
(Required, object map of string:double)
9696
Object that maps the field value to the frequency encoded value.
97+
98+
`custom`::
99+
include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=custom-preprocessor]
100+
97101
======
98102
//End frequency encoding
99103

@@ -112,6 +116,10 @@ The field name to encode.
112116
`hot_map`::
113117
(Required, object map of strings)
114118
String map of "field_value: one_hot_column_name".
119+
120+
`custom`::
121+
include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=custom-preprocessor]
122+
115123
======
116124
//End one hot encoding
117125

@@ -138,6 +146,10 @@ The field name to encode.
138146
`target_map`:::
139147
(Required, object map of string:double)
140148
Object that maps the field value to the target mean value.
149+
150+
`custom`::
151+
include::{es-repo-dir}/ml/ml-shared.asciidoc[tag=custom-preprocessor]
152+
141153
======
142154
//End target mean encoding
143155
=====

0 commit comments

Comments
 (0)