Skip to content

Commit 92a6dd4

Browse files
committed
Apply review comments
1 parent 99a8c32 commit 92a6dd4

File tree

6 files changed

+162
-67
lines changed

6 files changed

+162
-67
lines changed

client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/evaluation/classification/MulticlassConfusionMatrixMetric.java

+50-34
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,6 @@
3232
import java.util.List;
3333
import java.util.Objects;
3434

35-
import static org.elasticsearch.common.xcontent.ConstructingObjectParser.constructorArg;
3635
import static org.elasticsearch.common.xcontent.ConstructingObjectParser.optionalConstructorArg;
3736

3837
/**
@@ -103,23 +102,22 @@ public static class Result implements EvaluationMetric.Result {
103102
@SuppressWarnings("unchecked")
104103
private static final ConstructingObjectParser<Result, Void> PARSER =
105104
new ConstructingObjectParser<>(
106-
"multiclass_confusion_matrix_result", true, a -> new Result((List<ActualClass>) a[0], (long) a[1]));
105+
"multiclass_confusion_matrix_result", true, a -> new Result((List<ActualClass>) a[0], (Long) a[1]));
107106

108107
static {
109-
PARSER.declareObjectArray(constructorArg(), ActualClass.PARSER, CONFUSION_MATRIX);
110-
PARSER.declareLong(constructorArg(), OTHER_ACTUAL_CLASS_COUNT);
108+
PARSER.declareObjectArray(optionalConstructorArg(), ActualClass.PARSER, CONFUSION_MATRIX);
109+
PARSER.declareLong(optionalConstructorArg(), OTHER_ACTUAL_CLASS_COUNT);
111110
}
112111

113112
public static Result fromXContent(XContentParser parser) {
114113
return PARSER.apply(parser, null);
115114
}
116115

117-
// Immutable
118116
private final List<ActualClass> confusionMatrix;
119-
private final long otherActualClassCount;
117+
private final Long otherActualClassCount;
120118

121-
public Result(List<ActualClass> confusionMatrix, long otherActualClassCount) {
122-
this.confusionMatrix = Collections.unmodifiableList(Objects.requireNonNull(confusionMatrix));
119+
public Result(@Nullable List<ActualClass> confusionMatrix, @Nullable Long otherActualClassCount) {
120+
this.confusionMatrix = confusionMatrix != null ? Collections.unmodifiableList(Objects.requireNonNull(confusionMatrix)) : null;
123121
this.otherActualClassCount = otherActualClassCount;
124122
}
125123

@@ -132,15 +130,19 @@ public List<ActualClass> getConfusionMatrix() {
132130
return confusionMatrix;
133131
}
134132

135-
public long getOtherActualClassCount() {
133+
public Long getOtherActualClassCount() {
136134
return otherActualClassCount;
137135
}
138136

139137
@Override
140138
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
141139
builder.startObject();
142-
builder.field(CONFUSION_MATRIX.getPreferredName(), confusionMatrix);
143-
builder.field(OTHER_ACTUAL_CLASS_COUNT.getPreferredName(), otherActualClassCount);
140+
if (confusionMatrix != null) {
141+
builder.field(CONFUSION_MATRIX.getPreferredName(), confusionMatrix);
142+
}
143+
if (otherActualClassCount != null) {
144+
builder.field(OTHER_ACTUAL_CLASS_COUNT.getPreferredName(), otherActualClassCount);
145+
}
144146
builder.endObject();
145147
return builder;
146148
}
@@ -151,7 +153,7 @@ public boolean equals(Object o) {
151153
if (o == null || getClass() != o.getClass()) return false;
152154
Result that = (Result) o;
153155
return Objects.equals(this.confusionMatrix, that.confusionMatrix)
154-
&& this.otherActualClassCount == that.otherActualClassCount;
156+
&& Objects.equals(this.otherActualClassCount, that.otherActualClassCount);
155157
}
156158

157159
@Override
@@ -172,35 +174,45 @@ public static class ActualClass implements ToXContentObject {
172174
new ConstructingObjectParser<>(
173175
"multiclass_confusion_matrix_actual_class",
174176
true,
175-
a -> new ActualClass((String) a[0], (long) a[1], (List<PredictedClass>) a[2], (long) a[3]));
177+
a -> new ActualClass((String) a[0], (Long) a[1], (List<PredictedClass>) a[2], (Long) a[3]));
176178

177179
static {
178-
PARSER.declareString(constructorArg(), ACTUAL_CLASS);
179-
PARSER.declareLong(constructorArg(), ACTUAL_CLASS_DOC_COUNT);
180-
PARSER.declareObjectArray(constructorArg(), PredictedClass.PARSER, PREDICTED_CLASSES);
181-
PARSER.declareLong(constructorArg(), OTHER_PREDICTED_CLASS_DOC_COUNT);
180+
PARSER.declareString(optionalConstructorArg(), ACTUAL_CLASS);
181+
PARSER.declareLong(optionalConstructorArg(), ACTUAL_CLASS_DOC_COUNT);
182+
PARSER.declareObjectArray(optionalConstructorArg(), PredictedClass.PARSER, PREDICTED_CLASSES);
183+
PARSER.declareLong(optionalConstructorArg(), OTHER_PREDICTED_CLASS_DOC_COUNT);
182184
}
183185

184186
private final String actualClass;
185-
private final long actualClassDocCount;
187+
private final Long actualClassDocCount;
186188
private final List<PredictedClass> predictedClasses;
187-
private final long otherPredictedClassDocCount;
189+
private final Long otherPredictedClassDocCount;
188190

189-
public ActualClass(
190-
String actualClass, long actualClassDocCount, List<PredictedClass> predictedClasses, long otherPredictedClassDocCount) {
191+
public ActualClass(@Nullable String actualClass,
192+
@Nullable Long actualClassDocCount,
193+
@Nullable List<PredictedClass> predictedClasses,
194+
@Nullable Long otherPredictedClassDocCount) {
191195
this.actualClass = actualClass;
192196
this.actualClassDocCount = actualClassDocCount;
193-
this.predictedClasses = Collections.unmodifiableList(predictedClasses);
197+
this.predictedClasses = predictedClasses != null ? Collections.unmodifiableList(predictedClasses) : null;
194198
this.otherPredictedClassDocCount = otherPredictedClassDocCount;
195199
}
196200

197201
@Override
198202
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
199203
builder.startObject();
200-
builder.field(ACTUAL_CLASS.getPreferredName(), actualClass);
201-
builder.field(ACTUAL_CLASS_DOC_COUNT.getPreferredName(), actualClassDocCount);
202-
builder.field(PREDICTED_CLASSES.getPreferredName(), predictedClasses);
203-
builder.field(OTHER_PREDICTED_CLASS_DOC_COUNT.getPreferredName(), otherPredictedClassDocCount);
204+
if (actualClass != null) {
205+
builder.field(ACTUAL_CLASS.getPreferredName(), actualClass);
206+
}
207+
if (actualClassDocCount != null) {
208+
builder.field(ACTUAL_CLASS_DOC_COUNT.getPreferredName(), actualClassDocCount);
209+
}
210+
if (predictedClasses != null) {
211+
builder.field(PREDICTED_CLASSES.getPreferredName(), predictedClasses);
212+
}
213+
if (otherPredictedClassDocCount != null) {
214+
builder.field(OTHER_PREDICTED_CLASS_DOC_COUNT.getPreferredName(), otherPredictedClassDocCount);
215+
}
204216
builder.endObject();
205217
return builder;
206218
}
@@ -211,9 +223,9 @@ public boolean equals(Object o) {
211223
if (o == null || getClass() != o.getClass()) return false;
212224
ActualClass that = (ActualClass) o;
213225
return Objects.equals(this.actualClass, that.actualClass)
214-
&& this.actualClassDocCount == that.actualClassDocCount
226+
&& Objects.equals(this.actualClassDocCount, that.actualClassDocCount)
215227
&& Objects.equals(this.predictedClasses, that.predictedClasses)
216-
&& this.otherPredictedClassDocCount == that.otherPredictedClassDocCount;
228+
&& Objects.equals(this.otherPredictedClassDocCount, that.otherPredictedClassDocCount);
217229
}
218230

219231
@Override
@@ -235,26 +247,30 @@ public static class PredictedClass implements ToXContentObject {
235247
@SuppressWarnings("unchecked")
236248
private static final ConstructingObjectParser<PredictedClass, Void> PARSER =
237249
new ConstructingObjectParser<>(
238-
"multiclass_confusion_matrix_predicted_class", true, a -> new PredictedClass((String) a[0], (long) a[1]));
250+
"multiclass_confusion_matrix_predicted_class", true, a -> new PredictedClass((String) a[0], (Long) a[1]));
239251

240252
static {
241-
PARSER.declareString(constructorArg(), PREDICTED_CLASS);
242-
PARSER.declareLong(constructorArg(), COUNT);
253+
PARSER.declareString(optionalConstructorArg(), PREDICTED_CLASS);
254+
PARSER.declareLong(optionalConstructorArg(), COUNT);
243255
}
244256

245257
private final String predictedClass;
246258
private final Long count;
247259

248-
public PredictedClass(String predictedClass, Long count) {
260+
public PredictedClass(@Nullable String predictedClass, @Nullable Long count) {
249261
this.predictedClass = predictedClass;
250262
this.count = count;
251263
}
252264

253265
@Override
254266
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
255267
builder.startObject();
256-
builder.field(PREDICTED_CLASS.getPreferredName(), predictedClass);
257-
builder.field(COUNT.getPreferredName(), count);
268+
if (predictedClass != null) {
269+
builder.field(PREDICTED_CLASS.getPreferredName(), predictedClass);
270+
}
271+
if (count != null) {
272+
builder.field(COUNT.getPreferredName(), count);
273+
}
258274
builder.endObject();
259275
return builder;
260276
}

client/rest-high-level/src/test/java/org/elasticsearch/client/MachineLearningIT.java

+8-8
Original file line numberDiff line numberDiff line change
@@ -1805,19 +1805,19 @@ public void testEvaluateDataFrame_Classification() throws IOException {
18051805
List.of(
18061806
new ActualClass(
18071807
"ant",
1808-
1,
1808+
1L,
18091809
List.of(new PredictedClass("ant", 0L), new PredictedClass("cat", 1L), new PredictedClass("dog", 0L)),
1810-
0),
1810+
0L),
18111811
new ActualClass(
18121812
"cat",
1813-
5,
1813+
5L,
18141814
List.of(new PredictedClass("ant", 0L), new PredictedClass("cat", 3L), new PredictedClass("dog", 1L)),
1815-
1),
1815+
1L),
18161816
new ActualClass(
18171817
"dog",
1818-
4,
1818+
4L,
18191819
List.of(new PredictedClass("ant", 0L), new PredictedClass("cat", 1L), new PredictedClass("dog", 3L)),
1820-
0))));
1820+
0L))));
18211821
assertThat(mcmResult.getOtherActualClassCount(), equalTo(0L));
18221822
}
18231823
{ // Explicit size provided for MulticlassConfusionMatrixMetric metric
@@ -1839,8 +1839,8 @@ public void testEvaluateDataFrame_Classification() throws IOException {
18391839
mcmResult.getConfusionMatrix(),
18401840
equalTo(
18411841
List.of(
1842-
new ActualClass("cat", 5, List.of(new PredictedClass("cat", 3L), new PredictedClass("dog", 1L)), 1),
1843-
new ActualClass("dog", 4, List.of(new PredictedClass("cat", 1L), new PredictedClass("dog", 3L)), 0)
1842+
new ActualClass("cat", 5L, List.of(new PredictedClass("cat", 3L), new PredictedClass("dog", 1L)), 1L),
1843+
new ActualClass("dog", 4L, List.of(new PredictedClass("cat", 1L), new PredictedClass("dog", 3L)), 0L)
18441844
)));
18451845
assertThat(mcmResult.getOtherActualClassCount(), equalTo(1L));
18461846
}

client/rest-high-level/src/test/java/org/elasticsearch/client/documentation/MlClientDocumentationIT.java

+6-6
Original file line numberDiff line numberDiff line change
@@ -3368,19 +3368,19 @@ public void testEvaluateDataFrame_Classification() throws Exception {
33683368
List.of(
33693369
new ActualClass(
33703370
"ant",
3371-
1,
3371+
1L,
33723372
List.of(new PredictedClass("ant", 0L), new PredictedClass("cat", 1L), new PredictedClass("dog", 0L)),
3373-
0),
3373+
0L),
33743374
new ActualClass(
33753375
"cat",
3376-
5,
3376+
5L,
33773377
List.of(new PredictedClass("ant", 0L), new PredictedClass("cat", 3L), new PredictedClass("dog", 1L)),
3378-
1),
3378+
1L),
33793379
new ActualClass(
33803380
"dog",
3381-
4,
3381+
4L,
33823382
List.of(new PredictedClass("ant", 0L), new PredictedClass("cat", 1L), new PredictedClass("dog", 3L)),
3383-
0))));
3383+
0L))));
33843384
assertThat(otherClassesCount, equalTo(0L));
33853385
}
33863386
}

client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/evaluation/classification/MulticlassConfusionMatrixMetricResultTests.java

+8-3
Original file line numberDiff line numberDiff line change
@@ -48,11 +48,16 @@ protected Result createTestInstance() {
4848
for (int i = 0; i < numClasses; i++) {
4949
List<PredictedClass> predictedClasses = new ArrayList<>(numClasses);
5050
for (int j = 0; j < numClasses; j++) {
51-
predictedClasses.add(new PredictedClass(classNames.get(j), randomNonNegativeLong()));
51+
predictedClasses.add(new PredictedClass(classNames.get(j), randomBoolean() ? randomNonNegativeLong() : null));
5252
}
53-
actualClasses.add(new ActualClass(classNames.get(i), randomNonNegativeLong(), predictedClasses, randomNonNegativeLong()));
53+
actualClasses.add(
54+
new ActualClass(
55+
classNames.get(i),
56+
randomBoolean() ? randomNonNegativeLong() : null,
57+
predictedClasses,
58+
randomBoolean() ? randomNonNegativeLong() : null));
5459
}
55-
return new Result(actualClasses, randomNonNegativeLong());
60+
return new Result(actualClasses, randomBoolean() ? randomNonNegativeLong() : null);
5661
}
5762

5863
@Override

x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/MulticlassConfusionMatrix.java

+23-16
Original file line numberDiff line numberDiff line change
@@ -220,13 +220,13 @@ public static Result fromXContent(XContentParser parser) {
220220
private final long otherActualClassCount;
221221

222222
public Result(List<ActualClass> actualClasses, long otherActualClassCount) {
223-
this.actualClasses = Collections.unmodifiableList(Objects.requireNonNull(actualClasses));
224-
this.otherActualClassCount = otherActualClassCount;
223+
this.actualClasses = Collections.unmodifiableList(ExceptionsHelper.requireNonNull(actualClasses, CONFUSION_MATRIX));
224+
this.otherActualClassCount = requireNonNegative(otherActualClassCount, OTHER_ACTUAL_CLASS_COUNT);
225225
}
226226

227227
public Result(StreamInput in) throws IOException {
228228
this.actualClasses = Collections.unmodifiableList(in.readList(ActualClass::new));
229-
this.otherActualClassCount = in.readLong();
229+
this.otherActualClassCount = in.readVLong();
230230
}
231231

232232
@Override
@@ -250,7 +250,7 @@ public long getOtherActualClassCount() {
250250
@Override
251251
public void writeTo(StreamOutput out) throws IOException {
252252
out.writeList(actualClasses);
253-
out.writeLong(otherActualClassCount);
253+
out.writeVLong(otherActualClassCount);
254254
}
255255

256256
@Override
@@ -309,25 +309,25 @@ public static class ActualClass implements ToXContentObject, Writeable {
309309

310310
public ActualClass(
311311
String actualClass, long actualClassDocCount, List<PredictedClass> predictedClasses, long otherPredictedClassDocCount) {
312-
this.actualClass = actualClass;
313-
this.actualClassDocCount = actualClassDocCount;
314-
this.predictedClasses = Collections.unmodifiableList(predictedClasses);
315-
this.otherPredictedClassDocCount = otherPredictedClassDocCount;
312+
this.actualClass = ExceptionsHelper.requireNonNull(actualClass, ACTUAL_CLASS);
313+
this.actualClassDocCount = requireNonNegative(actualClassDocCount, ACTUAL_CLASS_DOC_COUNT);
314+
this.predictedClasses = Collections.unmodifiableList(ExceptionsHelper.requireNonNull(predictedClasses, PREDICTED_CLASSES));
315+
this.otherPredictedClassDocCount = requireNonNegative(otherPredictedClassDocCount, OTHER_PREDICTED_CLASS_DOC_COUNT);
316316
}
317317

318318
public ActualClass(StreamInput in) throws IOException {
319319
this.actualClass = in.readString();
320-
this.actualClassDocCount = in.readLong();
320+
this.actualClassDocCount = in.readVLong();
321321
this.predictedClasses = Collections.unmodifiableList(in.readList(PredictedClass::new));
322-
this.otherPredictedClassDocCount = in.readLong();
322+
this.otherPredictedClassDocCount = in.readVLong();
323323
}
324324

325325
@Override
326326
public void writeTo(StreamOutput out) throws IOException {
327327
out.writeString(actualClass);
328-
out.writeLong(actualClassDocCount);
328+
out.writeVLong(actualClassDocCount);
329329
out.writeList(predictedClasses);
330-
out.writeLong(otherPredictedClassDocCount);
330+
out.writeVLong(otherPredictedClassDocCount);
331331
}
332332

333333
@Override
@@ -377,13 +377,13 @@ public static class PredictedClass implements ToXContentObject, Writeable {
377377
private final long count;
378378

379379
public PredictedClass(String predictedClass, long count) {
380-
this.predictedClass = predictedClass;
381-
this.count = count;
380+
this.predictedClass = ExceptionsHelper.requireNonNull(predictedClass, PREDICTED_CLASS);
381+
this.count = requireNonNegative(count, COUNT);
382382
}
383383

384384
public PredictedClass(StreamInput in) throws IOException {
385385
this.predictedClass = in.readString();
386-
this.count = in.readLong();
386+
this.count = in.readVLong();
387387
}
388388

389389
public String getPredictedClass() {
@@ -393,7 +393,7 @@ public String getPredictedClass() {
393393
@Override
394394
public void writeTo(StreamOutput out) throws IOException {
395395
out.writeString(predictedClass);
396-
out.writeLong(count);
396+
out.writeVLong(count);
397397
}
398398

399399
@Override
@@ -419,4 +419,11 @@ public int hashCode() {
419419
return Objects.hash(predictedClass, count);
420420
}
421421
}
422+
423+
private static long requireNonNegative(long value, ParseField field) {
424+
if (value < 0) {
425+
throw ExceptionsHelper.serverError("[" + field.getPreferredName() + "] must be >= 0, was: " + value);
426+
}
427+
return value;
428+
}
422429
}

0 commit comments

Comments
 (0)