Skip to content

Commit e6cbcf7

Browse files
[7.x] [ML] Persist/restore state for DFA classification (#50040) (#50147)
This commit adds persisting and restoring of state for data frame analytics classification jobs. Backport of #50040.
1 parent 1c3ce11 commit e6cbcf7

File tree

5 files changed

+29
-14
lines changed

x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/Classification.java

+2-2
Original file line numberDiff line numberDiff line change
@@ -253,12 +253,12 @@ public boolean supportsMissingValues() {
253253

254254
@Override
255255
public boolean persistsState() {
256-
return false;
256+
return true;
257257
}
258258

259259
@Override
260260
public String getStateDocId(String jobId) {
261-
throw new UnsupportedOperationException();
261+
return jobId + "_classification_state#1";
262262
}
263263

264264
@Override

x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/ClassificationTests.java

+7
Original file line numberDiff line numberDiff line change
@@ -208,4 +208,11 @@ public void testToXContent_GivenEmptyParams() throws IOException {
208208
assertThat(json, containsString("randomize_seed"));
209209
}
210210
}
211+
212+
public void testGetStateDocId() {
213+
Classification classification = createRandom();
214+
assertThat(classification.persistsState(), is(true));
215+
String randomId = randomAlphaOfLength(10);
216+
assertThat(classification.getStateDocId(randomId), equalTo(randomId + "_classification_state#1"));
217+
}
211218
}

x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/ClassificationIT.java

+7
Original file line numberDiff line numberDiff line change
@@ -95,6 +95,7 @@ public void testSingleNumericFeatureAndMixedTrainingAndNonTrainingRows() throws
9595

9696
assertProgress(jobId, 100, 100, 100, 100);
9797
assertThat(searchStoredProgress(jobId).getHits().getTotalHits().value, equalTo(1L));
98+
assertModelStatePersisted(stateDocId());
9899
assertInferenceModelPersisted(jobId);
99100
assertThatAuditMessagesMatch(jobId,
100101
"Created analytics with analysis type [classification]",
@@ -135,6 +136,7 @@ public void testWithOnlyTrainingRowsAndTrainingPercentIsHundred() throws Excepti
135136

136137
assertProgress(jobId, 100, 100, 100, 100);
137138
assertThat(searchStoredProgress(jobId).getHits().getTotalHits().value, equalTo(1L));
139+
assertModelStatePersisted(stateDocId());
138140
assertInferenceModelPersisted(jobId);
139141
assertThatAuditMessagesMatch(jobId,
140142
"Created analytics with analysis type [classification]",
@@ -195,6 +197,7 @@ public <T> void testWithOnlyTrainingRowsAndTrainingPercentIsFifty(
195197

196198
assertProgress(jobId, 100, 100, 100, 100);
197199
assertThat(searchStoredProgress(jobId).getHits().getTotalHits().value, equalTo(1L));
200+
assertModelStatePersisted(stateDocId());
198201
assertInferenceModelPersisted(jobId);
199202
assertThatAuditMessagesMatch(jobId,
200203
"Created analytics with analysis type [classification]",
@@ -447,4 +450,8 @@ private <T> void assertEvaluation(String dependentVariable, List<T> dependentVar
447450
assertThat(confusionMatrixResult.getOtherActualClassCount(), equalTo(0L));
448451
}
449452
}
453+
454+
protected String stateDocId() {
455+
return jobId + "_classification_state#1";
456+
}
450457
}

x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/MlNativeDataFrameAnalyticsIntegTestCase.java

+7
Original file line numberDiff line numberDiff line change
@@ -274,4 +274,11 @@ protected static Set<String> getTrainingRowsIds(String index) {
274274
assertThat(trainingRowsIds.isEmpty(), is(false));
275275
return trainingRowsIds;
276276
}
277+
278+
protected static void assertModelStatePersisted(String stateDocId) {
279+
SearchResponse searchResponse = client().prepareSearch(AnomalyDetectorsIndex.jobStateIndexPattern())
280+
.setQuery(QueryBuilders.idsQuery().addIds(stateDocId))
281+
.get();
282+
assertThat(searchResponse.getHits().getHits().length, equalTo(1));
283+
}
277284
}

x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/RegressionIT.java

+6-12
Original file line numberDiff line numberDiff line change
@@ -12,14 +12,12 @@
1212
import org.elasticsearch.action.search.SearchResponse;
1313
import org.elasticsearch.action.support.WriteRequest;
1414
import org.elasticsearch.common.unit.TimeValue;
15-
import org.elasticsearch.index.query.QueryBuilders;
1615
import org.elasticsearch.search.SearchHit;
1716
import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfig;
1817
import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsState;
1918
import org.elasticsearch.xpack.core.ml.dataframe.analyses.BoostedTreeParams;
2019
import org.elasticsearch.xpack.core.ml.dataframe.analyses.BoostedTreeParamsTests;
2120
import org.elasticsearch.xpack.core.ml.dataframe.analyses.Regression;
22-
import org.elasticsearch.xpack.core.ml.job.persistence.AnomalyDetectorsIndex;
2321
import org.junit.After;
2422

2523
import java.util.Arrays;
@@ -82,7 +80,7 @@ public void testSingleNumericFeatureAndMixedTrainingAndNonTrainingRows() throws
8280

8381
assertProgress(jobId, 100, 100, 100, 100);
8482
assertThat(searchStoredProgress(jobId).getHits().getTotalHits().value, equalTo(1L));
85-
assertModelStatePersisted(jobId);
83+
assertModelStatePersisted(stateDocId());
8684
assertInferenceModelPersisted(jobId);
8785
assertThatAuditMessagesMatch(jobId,
8886
"Created analytics with analysis type [regression]",
@@ -119,7 +117,7 @@ public void testWithOnlyTrainingRowsAndTrainingPercentIsHundred() throws Excepti
119117

120118
assertProgress(jobId, 100, 100, 100, 100);
121119
assertThat(searchStoredProgress(jobId).getHits().getTotalHits().value, equalTo(1L));
122-
assertModelStatePersisted(jobId);
120+
assertModelStatePersisted(stateDocId());
123121
assertInferenceModelPersisted(jobId);
124122
assertThatAuditMessagesMatch(jobId,
125123
"Created analytics with analysis type [regression]",
@@ -171,7 +169,7 @@ public void testWithOnlyTrainingRowsAndTrainingPercentIsFifty() throws Exception
171169

172170
assertProgress(jobId, 100, 100, 100, 100);
173171
assertThat(searchStoredProgress(jobId).getHits().getTotalHits().value, equalTo(1L));
174-
assertModelStatePersisted(jobId);
172+
assertModelStatePersisted(stateDocId());
175173
assertInferenceModelPersisted(jobId);
176174
assertThatAuditMessagesMatch(jobId,
177175
"Created analytics with analysis type [regression]",
@@ -233,7 +231,7 @@ public void testStopAndRestart() throws Exception {
233231

234232
assertProgress(jobId, 100, 100, 100, 100);
235233
assertThat(searchStoredProgress(jobId).getHits().getTotalHits().value, equalTo(1L));
236-
assertModelStatePersisted(jobId);
234+
assertModelStatePersisted(stateDocId());
237235
assertInferenceModelPersisted(jobId);
238236
}
239237

@@ -324,11 +322,7 @@ private static Map<String, Object> getMlResultsObjectFromDestDoc(Map<String, Obj
324322
return resultsObject;
325323
}
326324

327-
private static void assertModelStatePersisted(String jobId) {
328-
String docId = jobId + "_regression_state#1";
329-
SearchResponse searchResponse = client().prepareSearch(AnomalyDetectorsIndex.jobStateIndexPattern())
330-
.setQuery(QueryBuilders.idsQuery().addIds(docId))
331-
.get();
332-
assertThat(searchResponse.getHits().getHits().length, equalTo(1));
325+
protected String stateDocId() {
326+
return jobId + "_regression_state#1";
333327
}
334328
}

0 commit comments

Comments
 (0)