Skip to content

Commit 1380dd4

Browse files
authored
[7.x] [ML][Inference] Fix weighted mode definition (#51648) (#51695)
* [ML][Inference] Fix weighted mode definition (#51648) Weighted mode inaccurately assumed that the "max value" of the input values would be the maximum class value. This does not make sense. Weighted Mode should know how many classes there are. Hence the new parameter `num_classes`. This indicates what the maximum class value to be expected.
1 parent 69ef9b0 commit 1380dd4

File tree

9 files changed

+220
-39
lines changed

9 files changed

+220
-39
lines changed

client/rest-high-level/src/main/java/org/elasticsearch/client/ml/inference/trainedmodel/ensemble/WeightedMode.java

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -34,13 +34,15 @@ public class WeightedMode implements OutputAggregator {
3434

3535
public static final String NAME = "weighted_mode";
3636
public static final ParseField WEIGHTS = new ParseField("weights");
37+
public static final ParseField NUM_CLASSES = new ParseField("num_classes");
3738

3839
@SuppressWarnings("unchecked")
3940
private static final ConstructingObjectParser<WeightedMode, Void> PARSER = new ConstructingObjectParser<>(
4041
NAME,
4142
true,
42-
a -> new WeightedMode((List<Double>)a[0]));
43+
a -> new WeightedMode((Integer)a[0], (List<Double>)a[1]));
4344
static {
45+
PARSER.declareInt(ConstructingObjectParser.constructorArg(), NUM_CLASSES);
4446
PARSER.declareDoubleArray(ConstructingObjectParser.optionalConstructorArg(), WEIGHTS);
4547
}
4648

@@ -49,9 +51,11 @@ public static WeightedMode fromXContent(XContentParser parser) {
4951
}
5052

5153
private final List<Double> weights;
54+
private final int numClasses;
5255

53-
public WeightedMode(List<Double> weights) {
56+
public WeightedMode(int numClasses, List<Double> weights) {
5457
this.weights = weights;
58+
this.numClasses = numClasses;
5559
}
5660

5761
@Override
@@ -65,6 +69,7 @@ public XContentBuilder toXContent(XContentBuilder builder, ToXContent.Params par
6569
if (weights != null) {
6670
builder.field(WEIGHTS.getPreferredName(), weights);
6771
}
72+
builder.field(NUM_CLASSES.getPreferredName(), numClasses);
6873
builder.endObject();
6974
return builder;
7075
}
@@ -74,11 +79,11 @@ public boolean equals(Object o) {
7479
if (this == o) return true;
7580
if (o == null || getClass() != o.getClass()) return false;
7681
WeightedMode that = (WeightedMode) o;
77-
return Objects.equals(weights, that.weights);
82+
return Objects.equals(weights, that.weights) && numClasses == that.numClasses;
7883
}
7984

8085
@Override
8186
public int hashCode() {
82-
return Objects.hash(weights);
87+
return Objects.hash(weights, numClasses);
8388
}
8489
}

client/rest-high-level/src/test/java/org/elasticsearch/client/ml/inference/trainedmodel/ensemble/EnsembleTests.java

Lines changed: 8 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,6 @@
3131

3232
import java.io.IOException;
3333
import java.util.ArrayList;
34-
import java.util.Arrays;
3534
import java.util.Collections;
3635
import java.util.List;
3736
import java.util.function.Predicate;
@@ -69,17 +68,17 @@ public static Ensemble createRandom(TargetType targetType) {
6968
List<TrainedModel> models = Stream.generate(() -> TreeTests.buildRandomTree(featureNames, 6, targetType))
7069
.limit(numberOfModels)
7170
.collect(Collectors.toList());
72-
List<Double> weights = Stream.generate(ESTestCase::randomDouble).limit(numberOfModels).collect(Collectors.toList());
73-
List<OutputAggregator> possibleAggregators = new ArrayList<>(Arrays.asList(new WeightedMode(weights),
74-
new LogisticRegression(weights)));
75-
if (targetType.equals(TargetType.REGRESSION)) {
76-
possibleAggregators.add(new WeightedSum(weights));
77-
}
78-
OutputAggregator outputAggregator = randomFrom(possibleAggregators.toArray(new OutputAggregator[0]));
7971
List<String> categoryLabels = null;
8072
if (randomBoolean() && targetType.equals(TargetType.CLASSIFICATION)) {
81-
categoryLabels = Arrays.asList(generateRandomStringArray(randomIntBetween(1, 10), randomIntBetween(1, 10), false, false));
73+
categoryLabels = randomList(2, randomIntBetween(3, 10), () -> randomAlphaOfLength(10));
8274
}
75+
List<Double> weights = Stream.generate(ESTestCase::randomDouble).limit(numberOfModels).collect(Collectors.toList());
76+
OutputAggregator outputAggregator = targetType == TargetType.REGRESSION ? new WeightedSum(weights) :
77+
randomFrom(
78+
new WeightedMode(
79+
categoryLabels != null ? categoryLabels.size() : randomIntBetween(2, 10),
80+
weights),
81+
new LogisticRegression(weights));
8382
double[] thresholds = randomBoolean() && targetType == TargetType.CLASSIFICATION ?
8483
Stream.generate(ESTestCase::randomDouble)
8584
.limit(categoryLabels == null ? randomIntBetween(1, 10) : categoryLabels.size())

client/rest-high-level/src/test/java/org/elasticsearch/client/ml/inference/trainedmodel/ensemble/WeightedModeTests.java

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,9 @@
3030
public class WeightedModeTests extends AbstractXContentTestCase<WeightedMode> {
3131

3232
WeightedMode createTestInstance(int numberOfWeights) {
33-
return new WeightedMode(Stream.generate(ESTestCase::randomDouble).limit(numberOfWeights).collect(Collectors.toList()));
33+
return new WeightedMode(
34+
randomIntBetween(2, 10),
35+
Stream.generate(ESTestCase::randomDouble).limit(numberOfWeights).collect(Collectors.toList()));
3436
}
3537

3638
@Override
@@ -45,7 +47,7 @@ protected boolean supportsUnknownFields() {
4547

4648
@Override
4749
protected WeightedMode createTestInstance() {
48-
return randomBoolean() ? new WeightedMode(null) : createTestInstance(randomIntBetween(1, 100));
50+
return randomBoolean() ? new WeightedMode(randomIntBetween(2, 10), null) : createTestInstance(randomIntBetween(1, 100));
4951
}
5052

5153
}

x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/WeightedMode.java

Lines changed: 25 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
import org.elasticsearch.common.xcontent.XContentBuilder;
1515
import org.elasticsearch.common.xcontent.XContentParser;
1616
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TargetType;
17+
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
1718

1819
import java.io.IOException;
1920
import java.util.ArrayList;
@@ -29,6 +30,7 @@ public class WeightedMode implements StrictlyParsedOutputAggregator, LenientlyPa
2930
private static final long SHALLOW_SIZE = RamUsageEstimator.shallowSizeOfInstance(WeightedMode.class);
3031
public static final ParseField NAME = new ParseField("weighted_mode");
3132
public static final ParseField WEIGHTS = new ParseField("weights");
33+
public static final ParseField NUM_CLASSES = new ParseField("num_classes");
3234

3335
private static final ConstructingObjectParser<WeightedMode, Void> LENIENT_PARSER = createParser(true);
3436
private static final ConstructingObjectParser<WeightedMode, Void> STRICT_PARSER = createParser(false);
@@ -38,7 +40,8 @@ private static ConstructingObjectParser<WeightedMode, Void> createParser(boolean
3840
ConstructingObjectParser<WeightedMode, Void> parser = new ConstructingObjectParser<>(
3941
NAME.getPreferredName(),
4042
lenient,
41-
a -> new WeightedMode((List<Double>)a[0]));
43+
a -> new WeightedMode((Integer) a[0], (List<Double>)a[1]));
44+
parser.declareInt(ConstructingObjectParser.constructorArg(), NUM_CLASSES);
4245
parser.declareDoubleArray(ConstructingObjectParser.optionalConstructorArg(), WEIGHTS);
4346
return parser;
4447
}
@@ -52,17 +55,23 @@ public static WeightedMode fromXContentLenient(XContentParser parser) {
5255
}
5356

5457
private final double[] weights;
58+
private final int numClasses;
5559

56-
WeightedMode() {
57-
this((List<Double>) null);
60+
WeightedMode(int numClasses) {
61+
this(numClasses, null);
5862
}
5963

60-
private WeightedMode(List<Double> weights) {
61-
this(weights == null ? null : weights.stream().mapToDouble(Double::valueOf).toArray());
64+
private WeightedMode(Integer numClasses, List<Double> weights) {
65+
this(weights == null ? null : weights.stream().mapToDouble(Double::valueOf).toArray(), numClasses);
6266
}
6367

64-
public WeightedMode(double[] weights) {
68+
public WeightedMode(double[] weights, Integer numClasses) {
6569
this.weights = weights;
70+
this.numClasses = ExceptionsHelper.requireNonNull(numClasses, NUM_CLASSES);
71+
if (this.numClasses <= 1) {
72+
throw new IllegalArgumentException("[" + NUM_CLASSES.getPreferredName() + "] must be greater than 1.");
73+
}
74+
6675
}
6776

6877
public WeightedMode(StreamInput in) throws IOException {
@@ -71,6 +80,7 @@ public WeightedMode(StreamInput in) throws IOException {
7180
} else {
7281
this.weights = null;
7382
}
83+
this.numClasses = in.readVInt();
7484
}
7585

7686
@Override
@@ -99,7 +109,10 @@ public List<Double> processValues(List<Double> values) {
99109
maxVal = integerValue;
100110
}
101111
}
102-
List<Double> frequencies = new ArrayList<>(Collections.nCopies(maxVal + 1, Double.NEGATIVE_INFINITY));
112+
if (maxVal >= numClasses) {
113+
throw new IllegalArgumentException("values contain entries larger than expected max of [" + (numClasses - 1) + "]");
114+
}
115+
List<Double> frequencies = new ArrayList<>(Collections.nCopies(numClasses, Double.NEGATIVE_INFINITY));
103116
for (int i = 0; i < freqArray.size(); i++) {
104117
Double weight = weights == null ? 1.0 : weights[i];
105118
Integer value = freqArray.get(i);
@@ -133,7 +146,7 @@ public String getName() {
133146

134147
@Override
135148
public boolean compatibleWith(TargetType targetType) {
136-
return true;
149+
return targetType.equals(TargetType.CLASSIFICATION);
137150
}
138151

139152
@Override
@@ -147,6 +160,7 @@ public void writeTo(StreamOutput out) throws IOException {
147160
if (weights != null) {
148161
out.writeDoubleArray(weights);
149162
}
163+
out.writeVInt(numClasses);
150164
}
151165

152166
@Override
@@ -155,6 +169,7 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
155169
if (weights != null) {
156170
builder.field(WEIGHTS.getPreferredName(), weights);
157171
}
172+
builder.field(NUM_CLASSES.getPreferredName(), numClasses);
158173
builder.endObject();
159174
return builder;
160175
}
@@ -164,12 +179,12 @@ public boolean equals(Object o) {
164179
if (this == o) return true;
165180
if (o == null || getClass() != o.getClass()) return false;
166181
WeightedMode that = (WeightedMode) o;
167-
return Arrays.equals(weights, that.weights);
182+
return Arrays.equals(weights, that.weights) && numClasses == that.numClasses;
168183
}
169184

170185
@Override
171186
public int hashCode() {
172-
return Arrays.hashCode(weights);
187+
return Objects.hash(Arrays.hashCode(weights), numClasses);
173188
}
174189

175190
@Override

x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/TrainedModelDefinitionTests.java

Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,8 @@
1818
import org.elasticsearch.xpack.core.ml.inference.preprocessing.FrequencyEncodingTests;
1919
import org.elasticsearch.xpack.core.ml.inference.preprocessing.OneHotEncodingTests;
2020
import org.elasticsearch.xpack.core.ml.inference.preprocessing.TargetMeanEncodingTests;
21+
import org.elasticsearch.xpack.core.ml.inference.results.ClassificationInferenceResults;
22+
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfig;
2123
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble.Ensemble;
2224
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble.EnsembleTests;
2325
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.tree.Tree;
@@ -26,7 +28,9 @@
2628
import java.io.IOException;
2729
import java.util.ArrayList;
2830
import java.util.Collections;
31+
import java.util.HashMap;
2932
import java.util.List;
33+
import java.util.Map;
3034
import java.util.function.Predicate;
3135
import java.util.stream.Collectors;
3236
import java.util.stream.Stream;
@@ -302,4 +306,70 @@ public void testRamUsageEstimation() {
302306
assertThat(test.ramBytesUsed(), greaterThan(0L));
303307
}
304308

309+
public void testMultiClassIrisInference() throws IOException {
310+
// Fairly simple, random forest classification model built to fit in our format
311+
// Trained on the well known Iris dataset
312+
String compressedDef = "H4sIAPbiMl4C/+1b246bMBD9lVWet8jjG3b/oN9QVYgmToLEkghIL6r23wukl90" +
313+
"YxRMGlt2WPKwEC/gYe2bOnBl+rOoyzQq3SR4OG5ev3t/9WLmicg+fc9cd1Gm5c3VSfz+2x6t1nlZVts3Wa" +
314+
"Z0ditX93Wrr0vpUuqRIH1zVXPJxVbljmie5K3b1vr3ifPw125wPj65+9u/z8fnfn+4vh0jy9LPLzw/+UGb" +
315+
"Vu8rVhyptb+wOv7iyytaH/FD+PZWVu6xo7u8e92x+3XOaSZVurtm1QydVXZ7W7XPPcIoGWpIVG/etOWbNR" +
316+
"Ru3zqp28r+B5bVrH5a7bZ2s91m+aU5Cc6LMdvu/Z3gL55hndfILdnNOtGPuS1ftD901LDKs+wFYziy3j/d" +
317+
"3FwjgKoJ0m3xJ81N7kvn3cix64aEH1gOfX8CXkVEtemFAahvz2IcgsBCkB0GhEMTKH1Ri3xn49yosYO0Bj" +
318+
"hErDpGy3Y9JLbjSRvoQNAF+jIVvPPi2Bz67gK8iK1v0ptmsWoHoWXFDQG+x9/IeQ8Hbqm+swBGT15dr1wM" +
319+
"CKDNA2yv0GKxE7b4+cwFBWDKQ+BlfDSgsat43tH94xD49diMtoeEVhgaN2mi6iwzMKqFjKUDPEBqCrmq6O" +
320+
"HHd0PViMreajEEFJxlaccAi4B4CgdhzHBHdOcFqCSYTI14g2WS2z0007DfAe4Hy7DdkrI2I+9yGIhitJhh" +
321+
"tTBjXYN+axcX1Ab7Oom2P+RgAtffDLj/A0a5vfkAbL/jWCwJHj9jT3afMzSQtQJYEhR6ibQ984+McsYQqg" +
322+
"m4baTBKMB6LHhDo/Aj8BInDcI6q0ePG/rgMx+57hkXnU+AnVGBxCWH3zq3ijclwI/tW3lC2jSVsWM4oN1O" +
323+
"SIc4XkjRGXjGEosylOUkUQ7AhhkBgSXYc1YvAksw4PG1kGWsAT5tOxbruOKbTnwIkSYxD1MbXsWAIUwMKz" +
324+
"eGUeDUbRwI9Fkek5CiwqAM3Bz6NUgdUt+vBslhIo8UM6kDQac4kDiicpHfe+FwY2SQI5q3oadvnoQ3hMHE" +
325+
"pCaHUgkqoVcRCG5aiKzCUCN03cUtJ4ikJxZTVlcWvDvarL626DiiVLH71pf0qG1y9H7mEPSQBNoTtQpFba" +
326+
"NzfDFfXSNJqPFJBkFb/1iiNLxhSAW3u4Ns7qHHi+i1F9fmyj1vV0sDIZonP0wh+waxjLr1vOPcmxORe7n3" +
327+
"pKOKIhVp9Rtb4+Owa3xCX/TpFPnrig6nKTNisNl8aNEKQRfQITh9kG/NhTzcvpwRZoARZvkh8S6h7Oz1zI" +
328+
"atZeuYWk5nvC4TJ2aFFJXBCTkcO9UuQQ0qb3FXdx4xTPH6dBeApP0CQ43QejN8kd7l64jI1krMVgJfPEf7" +
329+
"h3uq3o/K/ztZqP1QKFagz/G+t1XxwjeIFuqkRbXoTdlOTGnwCIoKZ6ku1AbrBoN6oCdX56w3UEOO0y2B9g" +
330+
"aLbAYWcAdpeweKa2IfIT2jz5QzXxD6AoP+DrdXtxeluV7pdWrvkcKqPp7rjS19d+wp/fff/5Ez3FPjzFNy" +
331+
"fdpTi9JB0sDp2JR7b309mn5HuPkEAAA==";
332+
333+
TrainedModelDefinition definition = InferenceToXContentCompressor.inflate(compressedDef,
334+
parser -> TrainedModelDefinition.fromXContent(parser, true).build(),
335+
xContentRegistry());
336+
337+
Map<String, Object> fields = new HashMap<String, Object>(){{
338+
put("sepal_length", 5.1);
339+
put("sepal_width", 3.5);
340+
put("petal_length", 1.4);
341+
put("petal_width", 0.2);
342+
}};
343+
344+
assertThat(
345+
((ClassificationInferenceResults)definition.getTrainedModel()
346+
.infer(fields, ClassificationConfig.EMPTY_PARAMS))
347+
.getClassificationLabel(),
348+
equalTo("Iris-setosa"));
349+
350+
fields = new HashMap<String, Object>(){{
351+
put("sepal_length", 7.0);
352+
put("sepal_width", 3.2);
353+
put("petal_length", 4.7);
354+
put("petal_width", 1.4);
355+
}};
356+
assertThat(
357+
((ClassificationInferenceResults)definition.getTrainedModel()
358+
.infer(fields, ClassificationConfig.EMPTY_PARAMS))
359+
.getClassificationLabel(),
360+
equalTo("Iris-versicolor"));
361+
362+
fields = new HashMap<String, Object>(){{
363+
put("sepal_length", 6.5);
364+
put("sepal_width", 3.0);
365+
put("petal_length", 5.2);
366+
put("petal_width", 2.0);
367+
}};
368+
assertThat(
369+
((ClassificationInferenceResults)definition.getTrainedModel()
370+
.infer(fields, ClassificationConfig.EMPTY_PARAMS))
371+
.getClassificationLabel(),
372+
equalTo("Iris-virginica"));
373+
}
374+
305375
}

0 commit comments

Comments
 (0)