Skip to content

Commit 39785eb

Browse files
[ML] Data frame analytics data counts (#53998)
This commit instruments data frame analytics with stats for the data that are being analyzed. In particular, we count training docs, test docs, and skipped docs. In order to count docs with missing values as skipped docs for analyses that do not support missing values, this commit changes the extractor so that it only ignores docs with missing values when it collects the data summary, which is used to estimate memory usage.
1 parent 0a35f39 commit 39785eb

File tree

28 files changed

+744
-123
lines changed

28 files changed

+744
-123
lines changed

client/rest-high-level/src/main/java/org/elasticsearch/client/ml/dataframe/DataFrameAnalyticsStats.java

+20-7
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121

2222
import org.elasticsearch.client.ml.NodeAttributes;
2323
import org.elasticsearch.client.ml.dataframe.stats.AnalysisStats;
24+
import org.elasticsearch.client.ml.dataframe.stats.common.DataCounts;
2425
import org.elasticsearch.client.ml.dataframe.stats.common.MemoryUsage;
2526
import org.elasticsearch.common.Nullable;
2627
import org.elasticsearch.common.ParseField;
@@ -47,6 +48,7 @@ public static DataFrameAnalyticsStats fromXContent(XContentParser parser) throws
4748
static final ParseField STATE = new ParseField("state");
4849
static final ParseField FAILURE_REASON = new ParseField("failure_reason");
4950
static final ParseField PROGRESS = new ParseField("progress");
51+
static final ParseField DATA_COUNTS = new ParseField("data_counts");
5052
static final ParseField MEMORY_USAGE = new ParseField("memory_usage");
5153
static final ParseField ANALYSIS_STATS = new ParseField("analysis_stats");
5254
static final ParseField NODE = new ParseField("node");
@@ -60,10 +62,11 @@ public static DataFrameAnalyticsStats fromXContent(XContentParser parser) throws
6062
(DataFrameAnalyticsState) args[1],
6163
(String) args[2],
6264
(List<PhaseProgress>) args[3],
63-
(MemoryUsage) args[4],
64-
(AnalysisStats) args[5],
65-
(NodeAttributes) args[6],
66-
(String) args[7]));
65+
(DataCounts) args[4],
66+
(MemoryUsage) args[5],
67+
(AnalysisStats) args[6],
68+
(NodeAttributes) args[7],
69+
(String) args[8]));
6770

6871
static {
6972
PARSER.declareString(constructorArg(), ID);
@@ -75,6 +78,7 @@ public static DataFrameAnalyticsStats fromXContent(XContentParser parser) throws
7578
}, STATE, ObjectParser.ValueType.STRING);
7679
PARSER.declareString(optionalConstructorArg(), FAILURE_REASON);
7780
PARSER.declareObjectArray(optionalConstructorArg(), PhaseProgress.PARSER, PROGRESS);
81+
PARSER.declareObject(optionalConstructorArg(), DataCounts.PARSER, DATA_COUNTS);
7882
PARSER.declareObject(optionalConstructorArg(), MemoryUsage.PARSER, MEMORY_USAGE);
7983
PARSER.declareObject(optionalConstructorArg(), (p, c) -> parseAnalysisStats(p), ANALYSIS_STATS);
8084
PARSER.declareObject(optionalConstructorArg(), NodeAttributes.PARSER, NODE);
@@ -93,19 +97,21 @@ private static AnalysisStats parseAnalysisStats(XContentParser parser) throws IO
9397
private final DataFrameAnalyticsState state;
9498
private final String failureReason;
9599
private final List<PhaseProgress> progress;
100+
private final DataCounts dataCounts;
96101
private final MemoryUsage memoryUsage;
97102
private final AnalysisStats analysisStats;
98103
private final NodeAttributes node;
99104
private final String assignmentExplanation;
100105

101106
public DataFrameAnalyticsStats(String id, DataFrameAnalyticsState state, @Nullable String failureReason,
102-
@Nullable List<PhaseProgress> progress, @Nullable MemoryUsage memoryUsage,
103-
@Nullable AnalysisStats analysisStats, @Nullable NodeAttributes node,
107+
@Nullable List<PhaseProgress> progress, @Nullable DataCounts dataCounts,
108+
@Nullable MemoryUsage memoryUsage, @Nullable AnalysisStats analysisStats, @Nullable NodeAttributes node,
104109
@Nullable String assignmentExplanation) {
105110
this.id = id;
106111
this.state = state;
107112
this.failureReason = failureReason;
108113
this.progress = progress;
114+
this.dataCounts = dataCounts;
109115
this.memoryUsage = memoryUsage;
110116
this.analysisStats = analysisStats;
111117
this.node = node;
@@ -128,6 +134,11 @@ public List<PhaseProgress> getProgress() {
128134
return progress;
129135
}
130136

137+
/**
 * Returns the data counts held by these stats.
 *
 * @return the data counts, or {@code null} when none were supplied; the parser
 *         declares {@code data_counts} as an optional constructor argument
 */
@Nullable
138+
public DataCounts getDataCounts() {
139+
return dataCounts;
140+
}
141+
131142
@Nullable
132143
public MemoryUsage getMemoryUsage() {
133144
return memoryUsage;
@@ -156,6 +167,7 @@ public boolean equals(Object o) {
156167
&& Objects.equals(state, other.state)
157168
&& Objects.equals(failureReason, other.failureReason)
158169
&& Objects.equals(progress, other.progress)
170+
&& Objects.equals(dataCounts, other.dataCounts)
159171
&& Objects.equals(memoryUsage, other.memoryUsage)
160172
&& Objects.equals(analysisStats, other.analysisStats)
161173
&& Objects.equals(node, other.node)
@@ -164,7 +176,7 @@ public boolean equals(Object o) {
164176

165177
@Override
166178
public int hashCode() {
167-
return Objects.hash(id, state, failureReason, progress, memoryUsage, analysisStats, node, assignmentExplanation);
179+
return Objects.hash(id, state, failureReason, progress, dataCounts, memoryUsage, analysisStats, node, assignmentExplanation);
168180
}
169181

170182
@Override
@@ -174,6 +186,7 @@ public String toString() {
174186
.add("state", state)
175187
.add("failureReason", failureReason)
176188
.add("progress", progress)
189+
.add("dataCounts", dataCounts)
177190
.add("memoryUsage", memoryUsage)
178191
.add("analysisStats", analysisStats)
179192
.add("node", node)
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,119 @@
1+
/*
2+
* Licensed to Elasticsearch under one or more contributor
3+
* license agreements. See the NOTICE file distributed with
4+
* this work for additional information regarding copyright
5+
* ownership. Elasticsearch licenses this file to you under
6+
* the Apache License, Version 2.0 (the "License"); you may
7+
* not use this file except in compliance with the License.
8+
* You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing,
13+
* software distributed under the License is distributed on an
14+
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15+
* KIND, either express or implied. See the License for the
16+
* specific language governing permissions and limitations
17+
* under the License.
18+
*/
19+
20+
package org.elasticsearch.client.ml.dataframe.stats.common;
21+
22+
import org.elasticsearch.common.Nullable;
23+
import org.elasticsearch.common.ParseField;
24+
import org.elasticsearch.common.inject.internal.ToStringBuilder;
25+
import org.elasticsearch.common.xcontent.ConstructingObjectParser;
26+
import org.elasticsearch.common.xcontent.ToXContentObject;
27+
import org.elasticsearch.common.xcontent.XContentBuilder;
28+
29+
import java.io.IOException;
30+
import java.util.Objects;
31+
32+
import static org.elasticsearch.common.xcontent.ConstructingObjectParser.optionalConstructorArg;
33+
34+
public class DataCounts implements ToXContentObject {
35+
36+
public static final String TYPE_VALUE = "analytics_data_counts";
37+
38+
public static final ParseField TRAINING_DOCS_COUNT = new ParseField("training_docs_count");
39+
public static final ParseField TEST_DOCS_COUNT = new ParseField("test_docs_count");
40+
public static final ParseField SKIPPED_DOCS_COUNT = new ParseField("skipped_docs_count");
41+
42+
public static final ConstructingObjectParser<DataCounts, Void> PARSER = new ConstructingObjectParser<>(TYPE_VALUE, true,
43+
a -> {
44+
Long trainingDocsCount = (Long) a[0];
45+
Long testDocsCount = (Long) a[1];
46+
Long skippedDocsCount = (Long) a[2];
47+
return new DataCounts(
48+
getOrDefault(trainingDocsCount, 0L),
49+
getOrDefault(testDocsCount, 0L),
50+
getOrDefault(skippedDocsCount, 0L)
51+
);
52+
});
53+
54+
static {
55+
PARSER.declareLong(optionalConstructorArg(), TRAINING_DOCS_COUNT);
56+
PARSER.declareLong(optionalConstructorArg(), TEST_DOCS_COUNT);
57+
PARSER.declareLong(optionalConstructorArg(), SKIPPED_DOCS_COUNT);
58+
}
59+
60+
private final long trainingDocsCount;
61+
private final long testDocsCount;
62+
private final long skippedDocsCount;
63+
64+
public DataCounts(long trainingDocsCount, long testDocsCount, long skippedDocsCount) {
65+
this.trainingDocsCount = trainingDocsCount;
66+
this.testDocsCount = testDocsCount;
67+
this.skippedDocsCount = skippedDocsCount;
68+
}
69+
70+
@Override
71+
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
72+
builder.startObject();
73+
builder.field(TRAINING_DOCS_COUNT.getPreferredName(), trainingDocsCount);
74+
builder.field(TEST_DOCS_COUNT.getPreferredName(), testDocsCount);
75+
builder.field(SKIPPED_DOCS_COUNT.getPreferredName(), skippedDocsCount);
76+
builder.endObject();
77+
return builder;
78+
}
79+
80+
@Override
81+
public boolean equals(Object o) {
82+
if (this == o) return true;
83+
if (o == null || getClass() != o.getClass()) return false;
84+
DataCounts that = (DataCounts) o;
85+
return trainingDocsCount == that.trainingDocsCount
86+
&& testDocsCount == that.testDocsCount
87+
&& skippedDocsCount == that.skippedDocsCount;
88+
}
89+
90+
@Override
91+
public int hashCode() {
92+
return Objects.hash(trainingDocsCount, testDocsCount, skippedDocsCount);
93+
}
94+
95+
@Override
96+
public String toString() {
97+
return new ToStringBuilder(getClass())
98+
.add(TRAINING_DOCS_COUNT.getPreferredName(), trainingDocsCount)
99+
.add(TEST_DOCS_COUNT.getPreferredName(), testDocsCount)
100+
.add(SKIPPED_DOCS_COUNT.getPreferredName(), skippedDocsCount)
101+
.toString();
102+
}
103+
104+
public long getTrainingDocsCount() {
105+
return trainingDocsCount;
106+
}
107+
108+
public long getTestDocsCount() {
109+
return testDocsCount;
110+
}
111+
112+
public long getSkippedDocsCount() {
113+
return skippedDocsCount;
114+
}
115+
116+
private static <T> T getOrDefault(@Nullable T value, T defaultValue) {
117+
return value != null ? value : defaultValue;
118+
}
119+
}

client/rest-high-level/src/test/java/org/elasticsearch/client/ml/dataframe/DataFrameAnalyticsStatsTests.java

+5
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
import org.elasticsearch.client.ml.dataframe.stats.AnalysisStats;
2424
import org.elasticsearch.client.ml.dataframe.stats.AnalysisStatsNamedXContentProvider;
2525
import org.elasticsearch.client.ml.dataframe.stats.classification.ClassificationStatsTests;
26+
import org.elasticsearch.client.ml.dataframe.stats.common.DataCountsTests;
2627
import org.elasticsearch.client.ml.dataframe.stats.common.MemoryUsageTests;
2728
import org.elasticsearch.client.ml.dataframe.stats.outlierdetection.OutlierDetectionStatsTests;
2829
import org.elasticsearch.client.ml.dataframe.stats.regression.RegressionStatsTests;
@@ -68,6 +69,7 @@ public static DataFrameAnalyticsStats randomDataFrameAnalyticsStats() {
6869
randomFrom(DataFrameAnalyticsState.values()),
6970
randomBoolean() ? null : randomAlphaOfLength(10),
7071
randomBoolean() ? null : createRandomProgress(),
72+
randomBoolean() ? null : DataCountsTests.createRandom(),
7173
randomBoolean() ? null : MemoryUsageTests.createRandom(),
7274
analysisStats,
7375
randomBoolean() ? null : NodeAttributesTests.createRandom(),
@@ -93,6 +95,9 @@ public static void toXContent(DataFrameAnalyticsStats stats, XContentBuilder bui
9395
if (stats.getProgress() != null) {
9496
builder.field(DataFrameAnalyticsStats.PROGRESS.getPreferredName(), stats.getProgress());
9597
}
98+
if (stats.getDataCounts() != null) {
99+
builder.field(DataFrameAnalyticsStats.DATA_COUNTS.getPreferredName(), stats.getDataCounts());
100+
}
96101
if (stats.getMemoryUsage() != null) {
97102
builder.field(DataFrameAnalyticsStats.MEMORY_USAGE.getPreferredName(), stats.getMemoryUsage());
98103
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
/*
2+
* Licensed to Elasticsearch under one or more contributor
3+
* license agreements. See the NOTICE file distributed with
4+
* this work for additional information regarding copyright
5+
* ownership. Elasticsearch licenses this file to you under
6+
* the Apache License, Version 2.0 (the "License"); you may
7+
* not use this file except in compliance with the License.
8+
* You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing,
13+
* software distributed under the License is distributed on an
14+
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15+
* KIND, either express or implied. See the License for the
16+
* specific language governing permissions and limitations
17+
* under the License.
18+
*/
19+
20+
package org.elasticsearch.client.ml.dataframe.stats.common;
21+
22+
import org.elasticsearch.common.xcontent.XContentParser;
23+
import org.elasticsearch.test.AbstractXContentTestCase;
24+
25+
import java.io.IOException;
26+
27+
public class DataCountsTests extends AbstractXContentTestCase<DataCounts> {
28+
29+
@Override
30+
protected DataCounts createTestInstance() {
31+
return createRandom();
32+
}
33+
34+
public static DataCounts createRandom() {
35+
return new DataCounts(
36+
randomNonNegativeLong(),
37+
randomNonNegativeLong(),
38+
randomNonNegativeLong()
39+
);
40+
}
41+
42+
@Override
43+
protected DataCounts doParseInstance(XContentParser parser) throws IOException {
44+
return DataCounts.PARSER.apply(parser, null);
45+
}
46+
47+
@Override
48+
protected boolean supportsUnknownFields() {
49+
return true;
50+
}
51+
}

x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/GetDataFrameAnalyticsStatsAction.java

+26-3
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfig;
3030
import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsState;
3131
import org.elasticsearch.xpack.core.ml.dataframe.stats.AnalysisStats;
32+
import org.elasticsearch.xpack.core.ml.dataframe.stats.common.DataCounts;
3233
import org.elasticsearch.xpack.core.ml.dataframe.stats.MemoryUsage;
3334
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
3435
import org.elasticsearch.xpack.core.ml.utils.PhaseProgress;
@@ -165,6 +166,9 @@ public static class Stats implements ToXContentObject, Writeable {
165166
*/
166167
private final List<PhaseProgress> progress;
167168

169+
@Nullable
170+
private final DataCounts dataCounts;
171+
168172
@Nullable
169173
private final MemoryUsage memoryUsage;
170174

@@ -177,12 +181,13 @@ public static class Stats implements ToXContentObject, Writeable {
177181
private final String assignmentExplanation;
178182

179183
public Stats(String id, DataFrameAnalyticsState state, @Nullable String failureReason, List<PhaseProgress> progress,
180-
@Nullable MemoryUsage memoryUsage, @Nullable AnalysisStats analysisStats, @Nullable DiscoveryNode node,
181-
@Nullable String assignmentExplanation) {
184+
@Nullable DataCounts dataCounts, @Nullable MemoryUsage memoryUsage, @Nullable AnalysisStats analysisStats,
185+
@Nullable DiscoveryNode node, @Nullable String assignmentExplanation) {
182186
this.id = Objects.requireNonNull(id);
183187
this.state = Objects.requireNonNull(state);
184188
this.failureReason = failureReason;
185189
this.progress = Objects.requireNonNull(progress);
190+
this.dataCounts = dataCounts;
186191
this.memoryUsage = memoryUsage;
187192
this.analysisStats = analysisStats;
188193
this.node = node;
@@ -198,6 +203,11 @@ public Stats(StreamInput in) throws IOException {
198203
} else {
199204
progress = in.readList(PhaseProgress::new);
200205
}
206+
if (in.getVersion().onOrAfter(Version.V_8_0_0)) {
207+
dataCounts = in.readOptionalWriteable(DataCounts::new);
208+
} else {
209+
dataCounts = null;
210+
}
201211
if (in.getVersion().onOrAfter(Version.V_7_7_0)) {
202212
memoryUsage = in.readOptionalWriteable(MemoryUsage::new);
203213
} else {
@@ -261,6 +271,11 @@ public List<PhaseProgress> getProgress() {
261271
return progress;
262272
}
263273

274+
/**
 * Returns the data counts held by these stats.
 *
 * @return the data counts, or {@code null} when none were provided; the stream
 *         constructor sets this to {@code null} for versions before 8.0.0
 */
@Nullable
275+
public DataCounts getDataCounts() {
276+
return dataCounts;
277+
}
278+
264279
@Nullable
265280
public MemoryUsage getMemoryUsage() {
266281
return memoryUsage;
@@ -293,6 +308,9 @@ public XContentBuilder toUnwrappedXContent(XContentBuilder builder) throws IOExc
293308
if (progress != null) {
294309
builder.field("progress", progress);
295310
}
311+
if (dataCounts != null) {
312+
builder.field("data_counts", dataCounts);
313+
}
296314
if (memoryUsage != null) {
297315
builder.field("memory_usage", memoryUsage);
298316
}
@@ -331,6 +349,9 @@ public void writeTo(StreamOutput out) throws IOException {
331349
} else {
332350
out.writeList(progress);
333351
}
352+
if (out.getVersion().onOrAfter(Version.V_8_0_0)) {
353+
out.writeOptionalWriteable(dataCounts);
354+
}
334355
if (out.getVersion().onOrAfter(Version.V_7_7_0)) {
335356
out.writeOptionalWriteable(memoryUsage);
336357
}
@@ -369,7 +390,8 @@ private void writeProgressToLegacy(StreamOutput out) throws IOException {
369390

370391
@Override
371392
public int hashCode() {
372-
return Objects.hash(id, state, failureReason, progress, memoryUsage, analysisStats, node, assignmentExplanation);
393+
return Objects.hash(id, state, failureReason, progress, dataCounts, memoryUsage, analysisStats, node,
394+
assignmentExplanation);
373395
}
374396

375397
@Override
@@ -385,6 +407,7 @@ public boolean equals(Object obj) {
385407
&& Objects.equals(this.state, other.state)
386408
&& Objects.equals(this.failureReason, other.failureReason)
387409
&& Objects.equals(this.progress, other.progress)
410+
&& Objects.equals(this.dataCounts, other.dataCounts)
388411
&& Objects.equals(this.memoryUsage, other.memoryUsage)
389412
&& Objects.equals(this.analysisStats, other.analysisStats)
390413
&& Objects.equals(this.node, other.node)

0 commit comments

Comments
 (0)