Skip to content

Commit 99a8c32

Browse files
committed
Rename other_predicted_class_count to other_predicted_class_doc_count
1 parent c761f63 commit 99a8c32

File tree

6 files changed

+41
-34
lines changed

6 files changed

+41
-34
lines changed

client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/evaluation/classification/MulticlassConfusionMatrixMetric.java

+14-8
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
import org.elasticsearch.client.ml.dataframe.evaluation.EvaluationMetric;
2222
import org.elasticsearch.common.Nullable;
2323
import org.elasticsearch.common.ParseField;
24+
import org.elasticsearch.common.Strings;
2425
import org.elasticsearch.common.xcontent.ConstructingObjectParser;
2526
import org.elasticsearch.common.xcontent.ToXContentObject;
2627
import org.elasticsearch.common.xcontent.XContentBuilder;
@@ -164,7 +165,7 @@ public static class ActualClass implements ToXContentObject {
164165
private static final ParseField ACTUAL_CLASS = new ParseField("actual_class");
165166
private static final ParseField ACTUAL_CLASS_DOC_COUNT = new ParseField("actual_class_doc_count");
166167
private static final ParseField PREDICTED_CLASSES = new ParseField("predicted_classes");
167-
private static final ParseField OTHER_PREDICTED_CLASS_COUNT = new ParseField("other_predicted_class_count");
168+
private static final ParseField OTHER_PREDICTED_CLASS_DOC_COUNT = new ParseField("other_predicted_class_doc_count");
168169

169170
@SuppressWarnings("unchecked")
170171
private static final ConstructingObjectParser<ActualClass, Void> PARSER =
@@ -177,20 +178,20 @@ public static class ActualClass implements ToXContentObject {
177178
PARSER.declareString(constructorArg(), ACTUAL_CLASS);
178179
PARSER.declareLong(constructorArg(), ACTUAL_CLASS_DOC_COUNT);
179180
PARSER.declareObjectArray(constructorArg(), PredictedClass.PARSER, PREDICTED_CLASSES);
180-
PARSER.declareLong(constructorArg(), OTHER_PREDICTED_CLASS_COUNT);
181+
PARSER.declareLong(constructorArg(), OTHER_PREDICTED_CLASS_DOC_COUNT);
181182
}
182183

183184
private final String actualClass;
184185
private final long actualClassDocCount;
185186
private final List<PredictedClass> predictedClasses;
186-
private final long otherPredictedClassCount;
187+
private final long otherPredictedClassDocCount;
187188

188189
public ActualClass(
189-
String actualClass, long actualClassDocCount, List<PredictedClass> predictedClasses, long otherPredictedClassCount) {
190+
String actualClass, long actualClassDocCount, List<PredictedClass> predictedClasses, long otherPredictedClassDocCount) {
190191
this.actualClass = actualClass;
191192
this.actualClassDocCount = actualClassDocCount;
192193
this.predictedClasses = Collections.unmodifiableList(predictedClasses);
193-
this.otherPredictedClassCount = otherPredictedClassCount;
194+
this.otherPredictedClassDocCount = otherPredictedClassDocCount;
194195
}
195196

196197
@Override
@@ -199,7 +200,7 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
199200
builder.field(ACTUAL_CLASS.getPreferredName(), actualClass);
200201
builder.field(ACTUAL_CLASS_DOC_COUNT.getPreferredName(), actualClassDocCount);
201202
builder.field(PREDICTED_CLASSES.getPreferredName(), predictedClasses);
202-
builder.field(OTHER_PREDICTED_CLASS_COUNT.getPreferredName(), otherPredictedClassCount);
203+
builder.field(OTHER_PREDICTED_CLASS_DOC_COUNT.getPreferredName(), otherPredictedClassDocCount);
203204
builder.endObject();
204205
return builder;
205206
}
@@ -212,12 +213,17 @@ public boolean equals(Object o) {
212213
return Objects.equals(this.actualClass, that.actualClass)
213214
&& this.actualClassDocCount == that.actualClassDocCount
214215
&& Objects.equals(this.predictedClasses, that.predictedClasses)
215-
&& this.otherPredictedClassCount == that.otherPredictedClassCount;
216+
&& this.otherPredictedClassDocCount == that.otherPredictedClassDocCount;
216217
}
217218

218219
@Override
219220
public int hashCode() {
220-
return Objects.hash(actualClass, actualClassDocCount, predictedClasses, otherPredictedClassCount);
221+
return Objects.hash(actualClass, actualClassDocCount, predictedClasses, otherPredictedClassDocCount);
222+
}
223+
224+
@Override
225+
public String toString() {
226+
return Strings.toString(this);
221227
}
222228
}
223229

client/rest-high-level/src/test/java/org/elasticsearch/client/documentation/MlClientDocumentationIT.java

+1-1
Original file line numberDiff line numberDiff line change
@@ -3373,7 +3373,7 @@ public void testEvaluateDataFrame_Classification() throws Exception {
33733373
0),
33743374
new ActualClass(
33753375
"cat",
3376-
4,
3376+
5,
33773377
List.of(new PredictedClass("ant", 0L), new PredictedClass("cat", 3L), new PredictedClass("dog", 1L)),
33783378
1),
33793379
new ActualClass(

client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/evaluation/classification/MulticlassConfusionMatrixMetricResultTests.java

+1-3
Original file line numberDiff line numberDiff line change
@@ -48,9 +48,7 @@ protected Result createTestInstance() {
4848
for (int i = 0; i < numClasses; i++) {
4949
List<PredictedClass> predictedClasses = new ArrayList<>(numClasses);
5050
for (int j = 0; j < numClasses; j++) {
51-
if (randomBoolean()) {
52-
predictedClasses.add(new PredictedClass(classNames.get(j), randomNonNegativeLong()));
53-
}
51+
predictedClasses.add(new PredictedClass(classNames.get(j), randomNonNegativeLong()));
5452
}
5553
actualClasses.add(new ActualClass(classNames.get(i), randomNonNegativeLong(), predictedClasses, randomNonNegativeLong()));
5654
}

x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/MulticlassConfusionMatrix.java

+19-14
Original file line numberDiff line numberDiff line change
@@ -147,18 +147,18 @@ public void process(Aggregations aggs) {
147147
long actualClassDocCount = bucket.getDocCount();
148148
Filters subAgg = bucket.getAggregations().get(STEP_2_AGGREGATE_BY_PREDICTED_CLASS);
149149
List<PredictedClass> predictedClasses = new ArrayList<>();
150-
long otherPredictedClassCount = 0;
150+
long otherPredictedClassDocCount = 0;
151151
for (Filters.Bucket subBucket : subAgg.getBuckets()) {
152152
String predictedClass = subBucket.getKeyAsString();
153153
long docCount = subBucket.getDocCount();
154154
if (OTHER_BUCKET_KEY.equals(predictedClass)) {
155-
otherPredictedClassCount = docCount;
155+
otherPredictedClassDocCount = docCount;
156156
} else {
157157
predictedClasses.add(new PredictedClass(predictedClass, docCount));
158158
}
159159
}
160160
predictedClasses.sort(comparing(PredictedClass::getPredictedClass));
161-
actualClasses.add(new ActualClass(actualClass, actualClassDocCount, predictedClasses, otherPredictedClassCount));
161+
actualClasses.add(new ActualClass(actualClass, actualClassDocCount, predictedClasses, otherPredictedClassDocCount));
162162
}
163163
result = new Result(actualClasses, Math.max(cardinalityAgg.getValue() - size, 0));
164164
}
@@ -214,8 +214,9 @@ public static Result fromXContent(XContentParser parser) {
214214
return PARSER.apply(parser, null);
215215
}
216216

217-
// Immutable
217+
/** List of actual classes. */
218218
private final List<ActualClass> actualClasses;
219+
/** Number of actual classes that were not included in the confusion matrix because there were too many of them. */
219220
private final long otherActualClassCount;
220221

221222
public Result(List<ActualClass> actualClasses, long otherActualClassCount) {
@@ -281,7 +282,7 @@ public static class ActualClass implements ToXContentObject, Writeable {
281282
private static final ParseField ACTUAL_CLASS = new ParseField("actual_class");
282283
private static final ParseField ACTUAL_CLASS_DOC_COUNT = new ParseField("actual_class_doc_count");
283284
private static final ParseField PREDICTED_CLASSES = new ParseField("predicted_classes");
284-
private static final ParseField OTHER_PREDICTED_CLASS_COUNT = new ParseField("other_predicted_class_count");
285+
private static final ParseField OTHER_PREDICTED_CLASS_DOC_COUNT = new ParseField("other_predicted_class_doc_count");
285286

286287
@SuppressWarnings("unchecked")
287288
private static final ConstructingObjectParser<ActualClass, Void> PARSER =
@@ -294,35 +295,39 @@ public static class ActualClass implements ToXContentObject, Writeable {
294295
PARSER.declareString(constructorArg(), ACTUAL_CLASS);
295296
PARSER.declareLong(constructorArg(), ACTUAL_CLASS_DOC_COUNT);
296297
PARSER.declareObjectArray(constructorArg(), PredictedClass.PARSER, PREDICTED_CLASSES);
297-
PARSER.declareLong(constructorArg(), OTHER_PREDICTED_CLASS_COUNT);
298+
PARSER.declareLong(constructorArg(), OTHER_PREDICTED_CLASS_DOC_COUNT);
298299
}
299300

301+
/** Name of the actual class. */
300302
private final String actualClass;
303+
/** Number of documents (examples) belonging to the {code actualClass} class. */
301304
private final long actualClassDocCount;
305+
/** List of predicted classes. */
302306
private final List<PredictedClass> predictedClasses;
303-
private final long otherPredictedClassCount;
307+
/** Number of documents that were not predicted as any of the {@code predictedClasses}. */
308+
private final long otherPredictedClassDocCount;
304309

305310
public ActualClass(
306-
String actualClass, long actualClassDocCount, List<PredictedClass> predictedClasses, long otherPredictedClassCount) {
311+
String actualClass, long actualClassDocCount, List<PredictedClass> predictedClasses, long otherPredictedClassDocCount) {
307312
this.actualClass = actualClass;
308313
this.actualClassDocCount = actualClassDocCount;
309314
this.predictedClasses = Collections.unmodifiableList(predictedClasses);
310-
this.otherPredictedClassCount = otherPredictedClassCount;
315+
this.otherPredictedClassDocCount = otherPredictedClassDocCount;
311316
}
312317

313318
public ActualClass(StreamInput in) throws IOException {
314319
this.actualClass = in.readString();
315320
this.actualClassDocCount = in.readLong();
316321
this.predictedClasses = Collections.unmodifiableList(in.readList(PredictedClass::new));
317-
this.otherPredictedClassCount = in.readLong();
322+
this.otherPredictedClassDocCount = in.readLong();
318323
}
319324

320325
@Override
321326
public void writeTo(StreamOutput out) throws IOException {
322327
out.writeString(actualClass);
323328
out.writeLong(actualClassDocCount);
324329
out.writeList(predictedClasses);
325-
out.writeLong(otherPredictedClassCount);
330+
out.writeLong(otherPredictedClassDocCount);
326331
}
327332

328333
@Override
@@ -331,7 +336,7 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
331336
builder.field(ACTUAL_CLASS.getPreferredName(), actualClass);
332337
builder.field(ACTUAL_CLASS_DOC_COUNT.getPreferredName(), actualClassDocCount);
333338
builder.field(PREDICTED_CLASSES.getPreferredName(), predictedClasses);
334-
builder.field(OTHER_PREDICTED_CLASS_COUNT.getPreferredName(), otherPredictedClassCount);
339+
builder.field(OTHER_PREDICTED_CLASS_DOC_COUNT.getPreferredName(), otherPredictedClassDocCount);
335340
builder.endObject();
336341
return builder;
337342
}
@@ -344,12 +349,12 @@ public boolean equals(Object o) {
344349
return Objects.equals(this.actualClass, that.actualClass)
345350
&& this.actualClassDocCount == that.actualClassDocCount
346351
&& Objects.equals(this.predictedClasses, that.predictedClasses)
347-
&& this.otherPredictedClassCount == that.otherPredictedClassCount;
352+
&& this.otherPredictedClassDocCount == that.otherPredictedClassDocCount;
348353
}
349354

350355
@Override
351356
public int hashCode() {
352-
return Objects.hash(actualClass, actualClassDocCount, predictedClasses, otherPredictedClassCount);
357+
return Objects.hash(actualClass, actualClassDocCount, predictedClasses, otherPredictedClassDocCount);
353358
}
354359
}
355360

x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/evaluation/classification/MulticlassConfusionMatrixResultTests.java

+1-3
Original file line numberDiff line numberDiff line change
@@ -28,9 +28,7 @@ public static Result createRandom() {
2828
for (int i = 0; i < numClasses; i++) {
2929
List<PredictedClass> predictedClasses = new ArrayList<>(numClasses);
3030
for (int j = 0; j < numClasses; j++) {
31-
if (randomBoolean()) {
32-
predictedClasses.add(new PredictedClass(classNames.get(j), randomNonNegativeLong()));
33-
}
31+
predictedClasses.add(new PredictedClass(classNames.get(j), randomNonNegativeLong()));
3432
}
3533
actualClasses.add(new ActualClass(classNames.get(i), randomNonNegativeLong(), predictedClasses, randomNonNegativeLong()));
3634
}

x-pack/plugin/src/test/resources/rest-api-spec/test/ml/evaluate_data_frame.yml

+5-5
Original file line numberDiff line numberDiff line change
@@ -630,7 +630,7 @@ setup:
630630
count: 1
631631
- predicted_class: "mouse"
632632
count: 0
633-
other_predicted_class_count: 0
633+
other_predicted_class_doc_count: 0
634634
- actual_class: "dog"
635635
actual_class_doc_count: 3
636636
predicted_classes:
@@ -640,7 +640,7 @@ setup:
640640
count: 2
641641
- predicted_class: "mouse"
642642
count: 0
643-
other_predicted_class_count: 0
643+
other_predicted_class_doc_count: 0
644644
- actual_class: "mouse"
645645
actual_class_doc_count: 2
646646
predicted_classes:
@@ -650,7 +650,7 @@ setup:
650650
count: 0
651651
- predicted_class: "mouse"
652652
count: 1
653-
other_predicted_class_count: 0
653+
other_predicted_class_doc_count: 0
654654
other_actual_class_count: 0
655655
---
656656
"Test classification multiclass_confusion_matrix with explicit size":
@@ -678,15 +678,15 @@ setup:
678678
count: 2
679679
- predicted_class: "dog"
680680
count: 1
681-
other_predicted_class_count: 0
681+
other_predicted_class_doc_count: 0
682682
- actual_class: "dog"
683683
actual_class_doc_count: 3
684684
predicted_classes:
685685
- predicted_class: "cat"
686686
count: 1
687687
- predicted_class: "dog"
688688
count: 2
689-
other_predicted_class_count: 0
689+
other_predicted_class_doc_count: 0
690690
other_actual_class_count: 1
691691
---
692692
"Test classification with null metrics":

0 commit comments

Comments
 (0)