Skip to content

Commit 17b904d

Browse files
[7.x][ML] Decouple DFA progress testing from analyses phases (#55925) (#56024)
This refactors native integ tests to assert progress without expecting explicit phases for analyses. We can test those with yaml tests in a single place. Backport of #55925
1 parent 273ff6a commit 17b904d

File tree

5 files changed

+58
-48
lines changed

5 files changed

+58
-48
lines changed

x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/ClassificationIT.java

Lines changed: 14 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -103,7 +103,7 @@ public void testSingleNumericFeatureAndMixedTrainingAndNonTrainingRows() throws
103103
putAnalytics(config);
104104

105105
assertIsStopped(jobId);
106-
assertProgress(jobId, 0, 0, 0, 0);
106+
assertProgressIsZero(jobId);
107107

108108
startAnalytics(jobId);
109109
waitUntilAnalyticsIsStopped(jobId);
@@ -121,7 +121,7 @@ public void testSingleNumericFeatureAndMixedTrainingAndNonTrainingRows() throws
121121
assertThat(importanceArray, hasSize(greaterThan(0)));
122122
}
123123

124-
assertProgress(jobId, 100, 100, 100, 100);
124+
assertProgressComplete(jobId);
125125
assertThat(searchStoredProgress(jobId).getHits().getTotalHits().value, equalTo(1L));
126126
assertModelStatePersisted(stateDocId());
127127
assertInferenceModelPersisted(jobId);
@@ -150,7 +150,7 @@ public void testWithOnlyTrainingRowsAndTrainingPercentIsHundred() throws Excepti
150150
putAnalytics(config);
151151

152152
assertIsStopped(jobId);
153-
assertProgress(jobId, 0, 0, 0, 0);
153+
assertProgressIsZero(jobId);
154154

155155
startAnalytics(jobId);
156156
waitUntilAnalyticsIsStopped(jobId);
@@ -171,7 +171,7 @@ public void testWithOnlyTrainingRowsAndTrainingPercentIsHundred() throws Excepti
171171
assertThat(stats.getDataCounts().getTestDocsCount(), equalTo(0L));
172172
assertThat(stats.getDataCounts().getSkippedDocsCount(), equalTo(0L));
173173

174-
assertProgress(jobId, 100, 100, 100, 100);
174+
assertProgressComplete(jobId);
175175
assertThat(searchStoredProgress(jobId).getHits().getTotalHits().value, equalTo(1L));
176176
assertModelStatePersisted(stateDocId());
177177
assertInferenceModelPersisted(jobId);
@@ -210,7 +210,7 @@ public <T> void testWithOnlyTrainingRowsAndTrainingPercentIsFifty(String jobId,
210210
putAnalytics(config);
211211

212212
assertIsStopped(jobId);
213-
assertProgress(jobId, 0, 0, 0, 0);
213+
assertProgressIsZero(jobId);
214214

215215
startAnalytics(jobId);
216216
waitUntilAnalyticsIsStopped(jobId);
@@ -245,7 +245,7 @@ public <T> void testWithOnlyTrainingRowsAndTrainingPercentIsFifty(String jobId,
245245
assertThat(stats.getDataCounts().getTestDocsCount(), lessThan(300L));
246246
assertThat(stats.getDataCounts().getSkippedDocsCount(), equalTo(0L));
247247

248-
assertProgress(jobId, 100, 100, 100, 100);
248+
assertProgressComplete(jobId);
249249
assertThat(searchStoredProgress(jobId).getHits().getTotalHits().value, equalTo(1L));
250250
assertModelStatePersisted(stateDocId());
251251
assertInferenceModelPersisted(jobId);
@@ -305,7 +305,7 @@ public void testStopAndRestart() throws Exception {
305305
putAnalytics(config);
306306

307307
assertIsStopped(jobId);
308-
assertProgress(jobId, 0, 0, 0, 0);
308+
assertProgressIsZero(jobId);
309309

310310
NodeAcknowledgedResponse response = startAnalytics(jobId);
311311
assertThat(response.getNode(), not(emptyString()));
@@ -346,7 +346,7 @@ public void testStopAndRestart() throws Exception {
346346
assertTopClasses(resultsObject, 2, KEYWORD_FIELD, KEYWORD_FIELD_VALUES);
347347
}
348348

349-
assertProgress(jobId, 100, 100, 100, 100);
349+
assertProgressComplete(jobId);
350350
assertThat(searchStoredProgress(jobId).getHits().getTotalHits().value, equalTo(1L));
351351
assertModelStatePersisted(stateDocId());
352352
assertInferenceModelPersisted(jobId);
@@ -394,7 +394,7 @@ public void testDependentVariableCardinalityTooHighButWithQueryMakesItWithinRang
394394
startAnalytics(jobId);
395395
waitUntilAnalyticsIsStopped(jobId);
396396

397-
assertProgress(jobId, 100, 100, 100, 100);
397+
assertProgressComplete(jobId);
398398
}
399399

400400
public void testDependentVariableIsNested() throws Exception {
@@ -407,7 +407,7 @@ public void testDependentVariableIsNested() throws Exception {
407407
startAnalytics(jobId);
408408
waitUntilAnalyticsIsStopped(jobId);
409409

410-
assertProgress(jobId, 100, 100, 100, 100);
410+
assertProgressComplete(jobId);
411411
assertThat(searchStoredProgress(jobId).getHits().getTotalHits().value, equalTo(1L));
412412
assertModelStatePersisted(stateDocId());
413413
assertInferenceModelPersisted(jobId);
@@ -425,7 +425,7 @@ public void testDependentVariableIsAliasToKeyword() throws Exception {
425425
startAnalytics(jobId);
426426
waitUntilAnalyticsIsStopped(jobId);
427427

428-
assertProgress(jobId, 100, 100, 100, 100);
428+
assertProgressComplete(jobId);
429429
assertThat(searchStoredProgress(jobId).getHits().getTotalHits().value, equalTo(1L));
430430
assertModelStatePersisted(stateDocId());
431431
assertInferenceModelPersisted(jobId);
@@ -443,7 +443,7 @@ public void testDependentVariableIsAliasToNested() throws Exception {
443443
startAnalytics(jobId);
444444
waitUntilAnalyticsIsStopped(jobId);
445445

446-
assertProgress(jobId, 100, 100, 100, 100);
446+
assertProgressComplete(jobId);
447447
assertThat(searchStoredProgress(jobId).getHits().getTotalHits().value, equalTo(1L));
448448
assertModelStatePersisted(stateDocId());
449449
assertInferenceModelPersisted(jobId);
@@ -539,7 +539,7 @@ public void testSetUpgradeMode_ExistingTaskGetsUnassigned() throws Exception {
539539
});
540540

541541
waitUntilAnalyticsIsStopped(jobId);
542-
assertProgress(jobId, 100, 100, 100, 100);
542+
assertProgressComplete(jobId);
543543
}
544544

545545
public void testSetUpgradeMode_NewTaskDoesNotStart() throws Exception {
@@ -572,7 +572,7 @@ public void testDeleteExpiredData_RemovesUnusedState() throws Exception {
572572
startAnalytics(jobId);
573573
waitUntilAnalyticsIsStopped(jobId);
574574

575-
assertProgress(jobId, 100, 100, 100, 100);
575+
assertProgressComplete(jobId);
576576
assertThat(searchStoredProgress(jobId).getHits().getTotalHits().value, equalTo(1L));
577577
assertModelStatePersisted(stateDocId());
578578
assertInferenceModelPersisted(jobId);

x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/MlNativeDataFrameAnalyticsIntegTestCase.java

Lines changed: 18 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,7 @@
6767
import static org.hamcrest.Matchers.anyOf;
6868
import static org.hamcrest.Matchers.arrayWithSize;
6969
import static org.hamcrest.Matchers.equalTo;
70+
import static org.hamcrest.Matchers.greaterThanOrEqualTo;
7071
import static org.hamcrest.Matchers.hasItems;
7172
import static org.hamcrest.Matchers.hasSize;
7273
import static org.hamcrest.Matchers.is;
@@ -199,19 +200,28 @@ protected void assertIsStopped(String id) {
199200
assertThat("Stats were: " + Strings.toString(stats), stats.getState(), equalTo(DataFrameAnalyticsState.STOPPED));
200201
}
201202

202-
protected void assertProgress(String id, int reindexing, int loadingData, int analyzing, int writingResults) {
203+
protected void assertProgressIsZero(String id) {
204+
List<PhaseProgress> progress = getProgress(id);
205+
assertThat("progress is not all zero: " + progress,
206+
progress.stream().allMatch(phaseProgress -> phaseProgress.getProgressPercent() == 0), is(true));
207+
}
208+
209+
protected void assertProgressComplete(String id) {
210+
List<PhaseProgress> progress = getProgress(id);
211+
assertThat("progress is complete: " + progress,
212+
progress.stream().allMatch(phaseProgress -> phaseProgress.getProgressPercent() == 100), is(true));
213+
}
214+
215+
private List<PhaseProgress> getProgress(String id) {
203216
GetDataFrameAnalyticsStatsAction.Response.Stats stats = getAnalyticsStats(id);
204217
assertThat(stats.getId(), equalTo(id));
205218
List<PhaseProgress> progress = stats.getProgress();
206-
assertThat(progress, hasSize(4));
219+
// We should have at least 4 phases: reindexing, loading_data, writing_results, plus at least one for the analysis
220+
assertThat(progress.size(), greaterThanOrEqualTo(4));
207221
assertThat(progress.get(0).getPhase(), equalTo("reindexing"));
208222
assertThat(progress.get(1).getPhase(), equalTo("loading_data"));
209-
assertThat(progress.get(2).getPhase(), equalTo("analyzing"));
210-
assertThat(progress.get(3).getPhase(), equalTo("writing_results"));
211-
assertThat(progress.get(0).getProgressPercent(), equalTo(reindexing));
212-
assertThat(progress.get(1).getProgressPercent(), equalTo(loadingData));
213-
assertThat(progress.get(2).getProgressPercent(), equalTo(analyzing));
214-
assertThat(progress.get(3).getProgressPercent(), equalTo(writingResults));
223+
assertThat(progress.get(progress.size() - 1).getPhase(), equalTo("writing_results"));
224+
return progress;
215225
}
216226

217227
protected SearchResponse searchStoredProgress(String jobId) {

x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/OutlierDetectionWithMissingFieldsIT.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,7 @@ public void testMissingFields() throws Exception {
7474
putAnalytics(config);
7575

7676
assertIsStopped(id);
77-
assertProgress(id, 0, 0, 0, 0);
77+
assertProgressIsZero(id);
7878

7979
startAnalytics(id);
8080
waitUntilAnalyticsIsStopped(id);
@@ -108,7 +108,7 @@ public void testMissingFields() throws Exception {
108108
}
109109
}
110110

111-
assertProgress(id, 100, 100, 100, 100);
111+
assertProgressComplete(id);
112112
assertThat(searchStoredProgress(id).getHits().getTotalHits().value, equalTo(1L));
113113
}
114114
}

x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/RegressionIT.java

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,7 @@ public void testSingleNumericFeatureAndMixedTrainingAndNonTrainingRows() throws
7272
putAnalytics(config);
7373

7474
assertIsStopped(jobId);
75-
assertProgress(jobId, 0, 0, 0, 0);
75+
assertProgressIsZero(jobId);
7676

7777
startAnalytics(jobId);
7878
waitUntilAnalyticsIsStopped(jobId);
@@ -101,7 +101,7 @@ public void testSingleNumericFeatureAndMixedTrainingAndNonTrainingRows() throws
101101
isPresent());
102102
}
103103

104-
assertProgress(jobId, 100, 100, 100, 100);
104+
assertProgressComplete(jobId);
105105
assertThat(searchStoredProgress(jobId).getHits().getTotalHits().value, equalTo(1L));
106106
assertModelStatePersisted(stateDocId());
107107
assertInferenceModelPersisted(jobId);
@@ -129,7 +129,7 @@ public void testWithOnlyTrainingRowsAndTrainingPercentIsHundred() throws Excepti
129129
putAnalytics(config);
130130

131131
assertIsStopped(jobId);
132-
assertProgress(jobId, 0, 0, 0, 0);
132+
assertProgressIsZero(jobId);
133133

134134
startAnalytics(jobId);
135135
waitUntilAnalyticsIsStopped(jobId);
@@ -143,7 +143,7 @@ public void testWithOnlyTrainingRowsAndTrainingPercentIsHundred() throws Excepti
143143
assertThat(resultsObject.get("is_training"), is(true));
144144
}
145145

146-
assertProgress(jobId, 100, 100, 100, 100);
146+
assertProgressComplete(jobId);
147147
assertThat(searchStoredProgress(jobId).getHits().getTotalHits().value, equalTo(1L));
148148

149149
GetDataFrameAnalyticsStatsAction.Response.Stats stats = getAnalyticsStats(jobId);
@@ -184,7 +184,7 @@ public void testWithOnlyTrainingRowsAndTrainingPercentIsFifty() throws Exception
184184
putAnalytics(config);
185185

186186
assertIsStopped(jobId);
187-
assertProgress(jobId, 0, 0, 0, 0);
187+
assertProgressIsZero(jobId);
188188

189189
startAnalytics(jobId);
190190
waitUntilAnalyticsIsStopped(jobId);
@@ -215,7 +215,7 @@ public void testWithOnlyTrainingRowsAndTrainingPercentIsFifty() throws Exception
215215
assertThat(stats.getDataCounts().getTestDocsCount(), lessThan(350L));
216216
assertThat(stats.getDataCounts().getSkippedDocsCount(), equalTo(0L));
217217

218-
assertProgress(jobId, 100, 100, 100, 100);
218+
assertProgressComplete(jobId);
219219
assertThat(searchStoredProgress(jobId).getHits().getTotalHits().value, equalTo(1L));
220220
assertModelStatePersisted(stateDocId());
221221
assertInferenceModelPersisted(jobId);
@@ -243,7 +243,7 @@ public void testStopAndRestart() throws Exception {
243243
putAnalytics(config);
244244

245245
assertIsStopped(jobId);
246-
assertProgress(jobId, 0, 0, 0, 0);
246+
assertProgressIsZero(jobId);
247247

248248
NodeAcknowledgedResponse response = startAnalytics(jobId);
249249
assertThat(response.getNode(), not(emptyString()));
@@ -284,7 +284,7 @@ public void testStopAndRestart() throws Exception {
284284
assertThat(resultsObject.get("is_training"), is(true));
285285
}
286286

287-
assertProgress(jobId, 100, 100, 100, 100);
287+
assertProgressComplete(jobId);
288288
assertThat(searchStoredProgress(jobId).getHits().getTotalHits().value, equalTo(1L));
289289
assertModelStatePersisted(stateDocId());
290290
assertInferenceModelPersisted(jobId);
@@ -342,7 +342,7 @@ public void testDeleteExpiredData_RemovesUnusedState() throws Exception {
342342
startAnalytics(jobId);
343343
waitUntilAnalyticsIsStopped(jobId);
344344

345-
assertProgress(jobId, 100, 100, 100, 100);
345+
assertProgressComplete(jobId);
346346
assertThat(searchStoredProgress(jobId).getHits().getTotalHits().value, equalTo(1L));
347347
assertModelStatePersisted(stateDocId());
348348
assertInferenceModelPersisted(jobId);
@@ -380,11 +380,11 @@ public void testDependentVariableIsLong() throws Exception {
380380
putAnalytics(config);
381381

382382
assertIsStopped(jobId);
383-
assertProgress(jobId, 0, 0, 0, 0);
383+
assertProgressIsZero(jobId);
384384

385385
startAnalytics(jobId);
386386
waitUntilAnalyticsIsStopped(jobId);
387-
assertProgress(jobId, 100, 100, 100, 100);
387+
assertProgressComplete(jobId);
388388

389389
assertMlResultsFieldMappings(destIndex, predictedClassField, "double");
390390
}

0 commit comments

Comments
 (0)