Skip to content

Commit f47da1d

Browse files
[ML] Restore analytics state if available (#47128)
This commit restores the model state if available in data frame analytics jobs. In addition, this changes the start API so that a stopped job can be restarted. As we now store the progress in the state index when the task is stopped, we can use it to determine what state the job was in when it got stopped. Note that in order to be able to distinguish between a job that runs for the first time and another that is restarting, we ensure reindexing progress is reported to be at least 1 for a running task.
1 parent f99096e commit f47da1d

31 files changed

+709
-153
lines changed

x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/StartDataFrameAnalyticsAction.java

Lines changed: 30 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
import org.elasticsearch.action.support.master.MasterNodeRequest;
1414
import org.elasticsearch.client.ElasticsearchClient;
1515
import org.elasticsearch.cluster.metadata.MetaData;
16+
import org.elasticsearch.common.Nullable;
1617
import org.elasticsearch.common.ParseField;
1718
import org.elasticsearch.common.Strings;
1819
import org.elasticsearch.common.io.stream.StreamInput;
@@ -29,8 +30,11 @@
2930
import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsConfig;
3031
import org.elasticsearch.xpack.core.ml.job.messages.Messages;
3132
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
33+
import org.elasticsearch.xpack.core.ml.utils.PhaseProgress;
3234

3335
import java.io.IOException;
36+
import java.util.Collections;
37+
import java.util.List;
3438
import java.util.Objects;
3539

3640
public class StartDataFrameAnalyticsAction extends ActionType<AcknowledgedResponse> {
@@ -150,12 +154,15 @@ public static class TaskParams implements PersistentTaskParams {
150154

151155
public static final Version VERSION_INTRODUCED = Version.V_7_3_0;
152156

157+
private static final ParseField PROGRESS_ON_START = new ParseField("progress_on_start");
158+
153159
public static ConstructingObjectParser<TaskParams, Void> PARSER = new ConstructingObjectParser<>(
154-
MlTasks.DATA_FRAME_ANALYTICS_TASK_NAME, true, a -> new TaskParams((String) a[0], (String) a[1]));
160+
MlTasks.DATA_FRAME_ANALYTICS_TASK_NAME, true, a -> new TaskParams((String) a[0], (String) a[1], (List<PhaseProgress>) a[2]));
155161

156162
static {
157163
PARSER.declareString(ConstructingObjectParser.constructorArg(), DataFrameAnalyticsConfig.ID);
158164
PARSER.declareString(ConstructingObjectParser.constructorArg(), DataFrameAnalyticsConfig.VERSION);
165+
PARSER.declareObjectArray(ConstructingObjectParser.optionalConstructorArg(), PhaseProgress.PARSER, PROGRESS_ON_START);
159166
}
160167

161168
public static TaskParams fromXContent(XContentParser parser) {
@@ -164,25 +171,36 @@ public static TaskParams fromXContent(XContentParser parser) {
164171

165172
private final String id;
166173
private final Version version;
174+
private final List<PhaseProgress> progressOnStart;
167175

168-
public TaskParams(String id, Version version) {
176+
public TaskParams(String id, Version version, List<PhaseProgress> progressOnStart) {
169177
this.id = Objects.requireNonNull(id);
170178
this.version = Objects.requireNonNull(version);
179+
this.progressOnStart = Collections.unmodifiableList(progressOnStart);
171180
}
172181

173-
private TaskParams(String id, String version) {
174-
this(id, Version.fromString(version));
182+
private TaskParams(String id, String version, @Nullable List<PhaseProgress> progressOnStart) {
183+
this(id, Version.fromString(version), progressOnStart == null ? Collections.emptyList() : progressOnStart);
175184
}
176185

177186
public TaskParams(StreamInput in) throws IOException {
178187
this.id = in.readString();
179188
this.version = Version.readVersion(in);
189+
if (in.getVersion().onOrAfter(Version.V_7_5_0)) {
190+
progressOnStart = in.readList(PhaseProgress::new);
191+
} else {
192+
progressOnStart = Collections.emptyList();
193+
}
180194
}
181195

182196
public String getId() {
183197
return id;
184198
}
185199

200+
public List<PhaseProgress> getProgressOnStart() {
201+
return progressOnStart;
202+
}
203+
186204
@Override
187205
public String getWriteableName() {
188206
return MlTasks.DATA_FRAME_ANALYTICS_TASK_NAME;
@@ -197,20 +215,24 @@ public Version getMinimalSupportedVersion() {
197215
public void writeTo(StreamOutput out) throws IOException {
198216
out.writeString(id);
199217
Version.writeVersion(version, out);
218+
if (out.getVersion().onOrAfter(Version.V_7_5_0)) {
219+
out.writeList(progressOnStart);
220+
}
200221
}
201222

202223
@Override
203224
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
204225
builder.startObject();
205226
builder.field(DataFrameAnalyticsConfig.ID.getPreferredName(), id);
206227
builder.field(DataFrameAnalyticsConfig.VERSION.getPreferredName(), version);
228+
builder.field(PROGRESS_ON_START.getPreferredName(), progressOnStart);
207229
builder.endObject();
208230
return builder;
209231
}
210232

211233
@Override
212234
public int hashCode() {
213-
return Objects.hash(id, version);
235+
return Objects.hash(id, version, progressOnStart);
214236
}
215237

216238
@Override
@@ -219,7 +241,9 @@ public boolean equals(Object o) {
219241
if (o == null || getClass() != o.getClass()) return false;
220242

221243
TaskParams other = (TaskParams) o;
222-
return Objects.equals(id, other.id) && Objects.equals(version, other.version);
244+
return Objects.equals(id, other.id)
245+
&& Objects.equals(version, other.version)
246+
&& Objects.equals(progressOnStart, other.progressOnStart);
223247
}
224248
}
225249

x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/DataFrameAnalysis.java

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,4 +37,9 @@ public interface DataFrameAnalysis extends ToXContentObject, NamedWriteable {
3737
* @return {@code true} if this analysis persists state that can later be used to restore from a given point
3838
*/
3939
boolean persistsState();
40+
41+
/**
42+
* Returns the document id for the analysis state
43+
*/
44+
String getStateDocId(String jobId);
4045
}

x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/OutlierDetection.java

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -174,6 +174,11 @@ public boolean persistsState() {
174174
return false;
175175
}
176176

177+
@Override
178+
public String getStateDocId(String jobId) {
179+
throw new UnsupportedOperationException("Outlier detection does not support state");
180+
}
181+
177182
public enum Method {
178183
LOF, LDOF, DISTANCE_KTH_NN, DISTANCE_KNN;
179184

x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/Regression.java

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -215,6 +215,11 @@ public boolean persistsState() {
215215
return true;
216216
}
217217

218+
@Override
219+
public String getStateDocId(String jobId) {
220+
return jobId + "_regression_state#1";
221+
}
222+
218223
@Override
219224
public int hashCode() {
220225
return Objects.hash(dependentVariable, lambda, gamma, eta, maximumNumberTrees, featureBagFraction, predictionFieldName,

x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/job/messages/Messages.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,7 @@ public final class Messages {
6666
public static final String DATA_FRAME_ANALYTICS_AUDIT_REUSING_DEST_INDEX = "Using existing destination index [{0}]";
6767
public static final String DATA_FRAME_ANALYTICS_AUDIT_FINISHED_REINDEXING = "Finished reindexing to destination index [{0}]";
6868
public static final String DATA_FRAME_ANALYTICS_AUDIT_FINISHED_ANALYSIS = "Finished analysis";
69+
public static final String DATA_FRAME_ANALYTICS_AUDIT_RESTORING_STATE = "Restoring from previous model state";
6970

7071
public static final String FILTER_CANNOT_DELETE = "Cannot delete filter [{0}] currently used by jobs {1}";
7172
public static final String FILTER_CONTAINS_TOO_MANY_ITEMS = "Filter [{0}] contains too many items; up to [{1}] items are allowed";

x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/StartDataFrameAnalyticsActionTaskParamsTests.java

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,11 @@
1010
import org.elasticsearch.common.io.stream.Writeable;
1111
import org.elasticsearch.common.xcontent.XContentParser;
1212
import org.elasticsearch.test.AbstractSerializingTestCase;
13+
import org.elasticsearch.xpack.core.ml.utils.PhaseProgress;
1314

1415
import java.io.IOException;
16+
import java.util.ArrayList;
17+
import java.util.List;
1518

1619
public class StartDataFrameAnalyticsActionTaskParamsTests extends AbstractSerializingTestCase<StartDataFrameAnalyticsAction.TaskParams> {
1720

@@ -22,7 +25,12 @@ protected StartDataFrameAnalyticsAction.TaskParams doParseInstance(XContentParse
2225

2326
@Override
2427
protected StartDataFrameAnalyticsAction.TaskParams createTestInstance() {
25-
return new StartDataFrameAnalyticsAction.TaskParams(randomAlphaOfLength(10), Version.CURRENT);
28+
int phaseCount = randomIntBetween(0, 5);
29+
List<PhaseProgress> progressOnStart = new ArrayList<>(phaseCount);
30+
for (int i = 0; i < phaseCount; i++) {
31+
progressOnStart.add(new PhaseProgress(randomAlphaOfLength(10), randomIntBetween(0, 100)));
32+
}
33+
return new StartDataFrameAnalyticsAction.TaskParams(randomAlphaOfLength(10), Version.CURRENT, progressOnStart);
2634
}
2735

2836
@Override

x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/OutlierDetectionTests.java

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,4 +56,10 @@ public void testGetParams_GivenExplicitValues() {
5656
assertThat((Double) params.get(OutlierDetection.FEATURE_INFLUENCE_THRESHOLD.getPreferredName()),
5757
is(closeTo(0.42, 1E-9)));
5858
}
59+
60+
public void testGetStateDocId() {
61+
OutlierDetection outlierDetection = createRandom();
62+
assertThat(outlierDetection.persistsState(), is(false));
63+
expectThrows(UnsupportedOperationException.class, () -> outlierDetection.getStateDocId("foo"));
64+
}
5965
}

x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/RegressionTests.java

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
import java.io.IOException;
1414

1515
import static org.hamcrest.Matchers.equalTo;
16+
import static org.hamcrest.Matchers.is;
1617

1718
public class RegressionTests extends AbstractSerializingTestCase<Regression> {
1819

@@ -124,4 +125,11 @@ public void testRegression_GivenTrainingPercentIsGreaterThan100() {
124125

125126
assertThat(e.getMessage(), equalTo("[training_percent] must be a double in [1, 100]"));
126127
}
128+
129+
public void testGetStateDocId() {
130+
Regression regression = createRandom();
131+
assertThat(regression.persistsState(), is(true));
132+
String randomId = randomAlphaOfLength(10);
133+
assertThat(regression.getStateDocId(randomId), equalTo(randomId + "_regression_state#1"));
134+
}
127135
}

x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/MlNativeDataFrameAnalyticsIntegTestCase.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -199,7 +199,8 @@ protected static void assertThatAuditMessagesMatch(String configId, String... ex
199199
assertBusy(() -> assertTrue(indexExists(AuditorField.NOTIFICATIONS_INDEX)));
200200
assertBusy(() -> {
201201
String[] actualAuditMessages = fetchAllAuditMessages(configId);
202-
assertThat(actualAuditMessages.length, equalTo(expectedAuditMessagePrefixes.length));
202+
assertThat("Messages: " + Arrays.toString(actualAuditMessages), actualAuditMessages.length,
203+
equalTo(expectedAuditMessagePrefixes.length));
203204
for (int i = 0; i < actualAuditMessages.length; i++) {
204205
assertThat(actualAuditMessages[i], startsWith(expectedAuditMessagePrefixes[i]));
205206
}

x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/RegressionIT.java

Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
import java.util.List;
2424
import java.util.Map;
2525

26+
import static org.hamcrest.Matchers.anyOf;
2627
import static org.hamcrest.Matchers.equalTo;
2728
import static org.hamcrest.Matchers.greaterThan;
2829
import static org.hamcrest.Matchers.is;
@@ -258,6 +259,87 @@ public void testWithOnlyTrainingRowsAndTrainingPercentIsFifty() throws Exception
258259
assertModelStatePersisted(jobId);
259260
}
260261

262+
public void testStopAndRestart() throws Exception {
263+
String jobId = "regression_stop_and_restart";
264+
String sourceIndex = jobId + "_source_index";
265+
266+
BulkRequestBuilder bulkRequestBuilder = client().prepareBulk();
267+
bulkRequestBuilder.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE);
268+
269+
List<Double> featureValues = Arrays.asList(1.0, 2.0, 3.0);
270+
List<Double> dependentVariableValues = Arrays.asList(10.0, 20.0, 30.0);
271+
272+
for (int i = 0; i < 350; i++) {
273+
Double field = featureValues.get(i % 3);
274+
Double value = dependentVariableValues.get(i % 3);
275+
276+
IndexRequest indexRequest = new IndexRequest(sourceIndex);
277+
indexRequest.source("feature", field, "variable", value);
278+
bulkRequestBuilder.add(indexRequest);
279+
}
280+
BulkResponse bulkResponse = bulkRequestBuilder.get();
281+
if (bulkResponse.hasFailures()) {
282+
fail("Failed to index data: " + bulkResponse.buildFailureMessage());
283+
}
284+
285+
String destIndex = sourceIndex + "_results";
286+
DataFrameAnalyticsConfig config = buildRegressionAnalytics(jobId, new String[] {sourceIndex}, destIndex, null,
287+
new Regression("variable"));
288+
registerAnalytics(config);
289+
putAnalytics(config);
290+
291+
assertState(jobId, DataFrameAnalyticsState.STOPPED);
292+
assertProgress(jobId, 0, 0, 0, 0);
293+
294+
startAnalytics(jobId);
295+
296+
// Wait until state is one of REINDEXING or ANALYZING, or until it is STOPPED.
297+
assertBusy(() -> {
298+
DataFrameAnalyticsState state = getAnalyticsStats(jobId).get(0).getState();
299+
assertThat(state, is(anyOf(equalTo(DataFrameAnalyticsState.REINDEXING), equalTo(DataFrameAnalyticsState.ANALYZING),
300+
equalTo(DataFrameAnalyticsState.STOPPED))));
301+
});
302+
stopAnalytics(jobId);
303+
waitUntilAnalyticsIsStopped(jobId);
304+
305+
// Now let's start it again
306+
try {
307+
startAnalytics(jobId);
308+
} catch (Exception e) {
309+
if (e.getMessage().equals("Cannot start because the job has already finished")) {
310+
// That means the job had managed to complete
311+
} else {
312+
throw e;
313+
}
314+
}
315+
316+
waitUntilAnalyticsIsStopped(jobId);
317+
318+
SearchResponse sourceData = client().prepareSearch(sourceIndex).setTrackTotalHits(true).setSize(1000).get();
319+
for (SearchHit hit : sourceData.getHits()) {
320+
GetResponse destDocGetResponse = client().prepareGet().setIndex(config.getDest().getIndex()).setId(hit.getId()).get();
321+
assertThat(destDocGetResponse.isExists(), is(true));
322+
Map<String, Object> sourceDoc = hit.getSourceAsMap();
323+
Map<String, Object> destDoc = destDocGetResponse.getSource();
324+
for (String field : sourceDoc.keySet()) {
325+
assertThat(destDoc.containsKey(field), is(true));
326+
assertThat(destDoc.get(field), equalTo(sourceDoc.get(field)));
327+
}
328+
assertThat(destDoc.containsKey("ml"), is(true));
329+
330+
@SuppressWarnings("unchecked")
331+
Map<String, Object> resultsObject = (Map<String, Object>) destDoc.get("ml");
332+
333+
assertThat(resultsObject.containsKey("variable_prediction"), is(true));
334+
assertThat(resultsObject.containsKey("is_training"), is(true));
335+
assertThat(resultsObject.get("is_training"), is(true));
336+
}
337+
338+
assertProgress(jobId, 100, 100, 100, 100);
339+
assertThat(searchStoredProgress(jobId).getHits().getTotalHits().value, equalTo(1L));
340+
assertModelStatePersisted(jobId);
341+
}
342+
261343
private void assertModelStatePersisted(String jobId) {
262344
String docId = jobId + "_regression_state#1";
263345
SearchResponse searchResponse = client().prepareSearch(AnomalyDetectorsIndex.jobStateIndexPattern())

x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/RunDataFrameAnalyticsIT.java

Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
import java.util.Map;
3232

3333
import static org.hamcrest.Matchers.allOf;
34+
import static org.hamcrest.Matchers.anyOf;
3435
import static org.hamcrest.Matchers.equalTo;
3536
import static org.hamcrest.Matchers.greaterThan;
3637
import static org.hamcrest.Matchers.greaterThanOrEqualTo;
@@ -476,4 +477,70 @@ public void testModelMemoryLimitLowerThanEstimatedMemoryUsage() throws Exception
476477
"Created analytics with analysis type [outlier_detection]",
477478
"Estimated memory usage for this analytics to be");
478479
}
480+
481+
public void testOutlierDetectionStopAndRestart() throws Exception {
482+
String sourceIndex = "test-outlier-detection-stop-and-restart";
483+
484+
client().admin().indices().prepareCreate(sourceIndex)
485+
.addMapping("_doc", "numeric_1", "type=double", "numeric_2", "type=float", "categorical_1", "type=keyword")
486+
.get();
487+
488+
BulkRequestBuilder bulkRequestBuilder = client().prepareBulk();
489+
bulkRequestBuilder.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE);
490+
491+
int docCount = randomIntBetween(1024, 2048);
492+
for (int i = 0; i < docCount; i++) {
493+
IndexRequest indexRequest = new IndexRequest(sourceIndex);
494+
indexRequest.source("numeric_1", randomDouble(), "numeric_2", randomFloat(), "categorical_1", randomAlphaOfLength(10));
495+
bulkRequestBuilder.add(indexRequest);
496+
}
497+
BulkResponse bulkResponse = bulkRequestBuilder.get();
498+
if (bulkResponse.hasFailures()) {
499+
fail("Failed to index data: " + bulkResponse.buildFailureMessage());
500+
}
501+
502+
String id = "test_outlier_detection_stop_and_restart";
503+
DataFrameAnalyticsConfig config = buildOutlierDetectionAnalytics(
504+
id, new String[] {sourceIndex}, sourceIndex + "-results", "custom_ml");
505+
registerAnalytics(config);
506+
putAnalytics(config);
507+
508+
assertState(id, DataFrameAnalyticsState.STOPPED);
509+
startAnalytics(id);
510+
511+
// Wait until state is one of REINDEXING or ANALYZING, or until it is STOPPED.
512+
assertBusy(() -> {
513+
DataFrameAnalyticsState state = getAnalyticsStats(id).get(0).getState();
514+
assertThat(state, is(anyOf(equalTo(DataFrameAnalyticsState.REINDEXING), equalTo(DataFrameAnalyticsState.ANALYZING),
515+
equalTo(DataFrameAnalyticsState.STOPPED))));
516+
});
517+
stopAnalytics(id);
518+
waitUntilAnalyticsIsStopped(id);
519+
520+
// Now let's start it again
521+
try {
522+
startAnalytics(id);
523+
} catch (Exception e) {
524+
if (e.getMessage().equals("Cannot start because the job has already finished")) {
525+
// That means the job had managed to complete
526+
} else {
527+
throw e;
528+
}
529+
}
530+
531+
waitUntilAnalyticsIsStopped(id);
532+
533+
// Check we've got all docs
534+
SearchResponse searchResponse = client().prepareSearch(config.getDest().getIndex()).setTrackTotalHits(true).get();
535+
assertThat(searchResponse.getHits().getTotalHits().value, equalTo((long) docCount));
536+
537+
// Check they all have an outlier_score
538+
searchResponse = client().prepareSearch(config.getDest().getIndex())
539+
.setTrackTotalHits(true)
540+
.setQuery(QueryBuilders.existsQuery("custom_ml.outlier_score")).get();
541+
assertThat(searchResponse.getHits().getTotalHits().value, equalTo((long) docCount));
542+
543+
assertProgress(id, 100, 100, 100, 100);
544+
assertThat(searchStoredProgress(id).getHits().getTotalHits().value, equalTo(1L));
545+
}
479546
}

0 commit comments

Comments
 (0)