Skip to content

Commit dd50520

Browse files
authored
[ML] Fix snapshot upgrader so that if state is not fully written or parseable the task fails (#65755)
It is possible that the snapshot upgrader execution path continues before the old model state is fully read by the native process. To prevent this, a flush request is made after the state is loaded. This is to verify that all the state has been read by the native process. This allows the task to fail if reading the state fails and prevents some strange race conditions. Closes #65699
1 parent ccad78e commit dd50520

File tree

4 files changed

+89
-13
lines changed

4 files changed

+89
-13
lines changed

client/rest-high-level/src/test/java/org/elasticsearch/client/MachineLearningIT.java

+2-7
Original file line numberDiff line numberDiff line change
@@ -2730,10 +2730,6 @@ private String createAndPutDatafeed(String jobId, String indexName) throws IOExc
27302730
}
27312731

27322732
public void createModelSnapshot(String jobId, String snapshotId) throws IOException {
2733-
createModelSnapshot(jobId, snapshotId, Version.CURRENT);
2734-
}
2735-
2736-
public void createModelSnapshot(String jobId, String snapshotId, Version minVersion) throws IOException {
27372733
String documentId = jobId + "_model_snapshot_" + snapshotId;
27382734
Job job = MachineLearningIT.buildJob(jobId);
27392735
highLevelClient().machineLearning().putJob(new PutJobRequest(job), RequestOptions.DEFAULT);
@@ -2747,7 +2743,7 @@ public void createModelSnapshot(String jobId, String snapshotId, Version minVers
27472743
"\"total_by_field_count\":3, \"total_over_field_count\":0, \"total_partition_field_count\":2," +
27482744
"\"bucket_allocation_failures_count\":0, \"memory_status\":\"ok\", \"log_time\":1541587919000, " +
27492745
"\"timestamp\":1519930800000}, \"latest_record_time_stamp\":1519931700000," +
2750-
"\"latest_result_time_stamp\":1519930800000, \"retain\":false, \"min_version\":\"" + minVersion.toString() + "\"}",
2746+
"\"latest_result_time_stamp\":1519930800000, \"retain\":false, \"min_version\":\"" + Version.CURRENT.toString() + "\"}",
27512747
XContentType.JSON);
27522748

27532749
highLevelClient().index(indexRequest, RequestOptions.DEFAULT);
@@ -2828,12 +2824,11 @@ public void testUpdateModelSnapshot() throws Exception {
28282824
getModelSnapshotsResponse2.snapshots().get(0).getDescription());
28292825
}
28302826

2831-
@AwaitsFix(bugUrl="https://github.com/elastic/elasticsearch/issues/65699")
28322827
public void testUpgradeJobSnapshot() throws Exception {
28332828
String jobId = "test-upgrade-model-snapshot";
28342829
String snapshotId = "1541587919";
28352830

2836-
createModelSnapshot(jobId, snapshotId, Version.CURRENT);
2831+
createModelSnapshot(jobId, snapshotId);
28372832
MachineLearningClient machineLearningClient = highLevelClient().machineLearning();
28382833
UpgradeJobModelSnapshotRequest request = new UpgradeJobModelSnapshotRequest(jobId, snapshotId, null, true);
28392834
ElasticsearchException ex = expectThrows(ElasticsearchException.class,

client/rest-high-level/src/test/java/org/elasticsearch/client/documentation/MlClientDocumentationIT.java

+1-2
Original file line numberDiff line numberDiff line change
@@ -2336,7 +2336,6 @@ public void onFailure(Exception e) {
23362336
}
23372337
}
23382338

2339-
@AwaitsFix(bugUrl="https://github.com/elastic/elasticsearch/issues/65699")
23402339
public void testUpgradeJobSnapshot() throws IOException, InterruptedException {
23412340
RestHighLevelClient client = highLevelClient();
23422341

@@ -2376,7 +2375,7 @@ public void testUpgradeJobSnapshot() throws IOException, InterruptedException {
23762375
// end::upgrade-job-model-snapshot-execute
23772376
fail("upgrade model snapshot should not have succeeded.");
23782377
} catch (ElasticsearchException ex) {
2379-
assertThat(ex.getMessage(), containsString("Expected persisted state but no state exists"));
2378+
assertThat(ex.getMessage(), containsString("Unexpected state [failed] while waiting for to be assigned to a node"));
23802379
}
23812380
UpgradeJobModelSnapshotResponse response = new UpgradeJobModelSnapshotResponse(true, "");
23822381

x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/job/process/autodetect/JobModelSnapshotUpgrader.java

+57-2
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
import org.elasticsearch.threadpool.ThreadPool;
2424
import org.elasticsearch.xpack.core.ml.job.config.AnalysisConfig;
2525
import org.elasticsearch.xpack.core.ml.job.config.Job;
26+
import org.elasticsearch.xpack.core.ml.job.process.autodetect.output.FlushAcknowledgement;
2627
import org.elasticsearch.xpack.core.ml.job.snapshot.upgrade.SnapshotUpgradeState;
2728
import org.elasticsearch.xpack.core.ml.job.snapshot.upgrade.SnapshotUpgradeTaskState;
2829
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
@@ -31,11 +32,13 @@
3132
import org.elasticsearch.xpack.ml.job.persistence.StateStreamer;
3233
import org.elasticsearch.xpack.ml.job.process.autodetect.output.JobSnapshotUpgraderResultProcessor;
3334
import org.elasticsearch.xpack.ml.job.process.autodetect.params.AutodetectParams;
35+
import org.elasticsearch.xpack.ml.job.process.autodetect.params.FlushJobParams;
3436
import org.elasticsearch.xpack.ml.job.snapshot.upgrader.SnapshotUpgradeTask;
3537
import org.elasticsearch.xpack.ml.process.NativeStorageProvider;
3638
import org.elasticsearch.xpack.ml.process.writer.LengthEncodedWriter;
3739

3840
import java.io.IOException;
41+
import java.time.Duration;
3942
import java.util.HashMap;
4043
import java.util.Map;
4144
import java.util.Objects;
@@ -48,7 +51,7 @@
4851
import static org.elasticsearch.xpack.ml.MachineLearning.UTILITY_THREAD_POOL_NAME;
4952

5053
public final class JobModelSnapshotUpgrader {
51-
54+
private static final Duration FLUSH_PROCESS_CHECK_FREQUENCY = Duration.ofSeconds(1);
5255
private static final Logger logger = LogManager.getLogger(JobModelSnapshotUpgrader.class);
5356

5457
private final SnapshotUpgradeTask task;
@@ -97,7 +100,9 @@ void start() {
97100
params,
98101
autodetectExecutorService,
99102
(reason) -> {
100-
setTaskToFailed(reason, ActionListener.wrap(t -> {}, f -> {}));
103+
setTaskToFailed(reason, ActionListener.wrap(t -> {
104+
}, f -> {
105+
}));
101106
try {
102107
nativeStorageProvider.cleanupLocalTmpStorage(task.getDescription());
103108
} catch (IOException e) {
@@ -200,6 +205,24 @@ void writeHeader() throws IOException {
200205
process.writeRecord(record);
201206
}
202207

208+
FlushAcknowledgement waitFlushToCompletion(String flushId) throws Exception {
209+
logger.debug(() -> new ParameterizedMessage("[{}] [{}] waiting for flush [{}]", jobId, snapshotId, flushId));
210+
211+
FlushAcknowledgement flushAcknowledgement;
212+
try {
213+
flushAcknowledgement = processor.waitForFlushAcknowledgement(flushId, FLUSH_PROCESS_CHECK_FREQUENCY);
214+
while (flushAcknowledgement == null) {
215+
checkProcessIsAlive();
216+
checkResultsProcessorIsAlive();
217+
flushAcknowledgement = processor.waitForFlushAcknowledgement(flushId, FLUSH_PROCESS_CHECK_FREQUENCY);
218+
}
219+
} finally {
220+
processor.clearAwaitingFlush(flushId);
221+
}
222+
logger.debug(() -> new ParameterizedMessage("[{}] [{}] flush completed [{}]", jobId, snapshotId, flushId));
223+
return flushAcknowledgement;
224+
}
225+
203226
void restoreState() {
204227
try {
205228
process.restoreState(stateStreamer, params.modelSnapshot());
@@ -209,6 +232,31 @@ void restoreState() {
209232
ActionListener.wrap(t -> shutdown(e), f -> shutdown(e)));
210233
return;
211234
}
235+
submitOperation(() -> {
236+
String flushId = process.flushJob(FlushJobParams.builder().waitForNormalization(false).build());
237+
return waitFlushToCompletion(flushId);
238+
}, (aVoid, e) -> {
239+
Runnable nextStep;
240+
if (e != null) {
241+
logger.error(
242+
() -> new ParameterizedMessage(
243+
"[{}] [{}] failed to flush after writing old state",
244+
jobId,
245+
snapshotId
246+
),
247+
e);
248+
nextStep = () -> setTaskToFailed(
249+
"Failed to flush after writing old state due to: " + e.getMessage(),
250+
ActionListener.wrap(t -> shutdown(e), f -> shutdown(e))
251+
);
252+
} else {
253+
nextStep = this::requestStateWrite;
254+
}
255+
threadPool.executor(UTILITY_THREAD_POOL_NAME).execute(nextStep);
256+
});
257+
}
258+
259+
private void requestStateWrite() {
212260
task.updatePersistentTaskState(
213261
new SnapshotUpgradeTaskState(SnapshotUpgradeState.SAVING_NEW_STATE, task.getAllocationId(), ""),
214262
ActionListener.wrap(
@@ -282,6 +330,13 @@ private void checkProcessIsAlive() {
282330
}
283331
}
284332

333+
private void checkResultsProcessorIsAlive() {
334+
if (processor.isFailed()) {
335+
// Don't log here - it just causes double logging when the exception gets logged
336+
throw new ElasticsearchException("[{}] Unexpected death of the result processor", job.getId());
337+
}
338+
}
339+
285340
void shutdown(Exception e) {
286341
// No point in sending an action to the executor if the process has died
287342
if (process.isProcessAlive() == false) {

x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/job/process/autodetect/output/JobSnapshotUpgraderResultProcessor.java

+29-2
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
import org.apache.logging.log4j.message.ParameterizedMessage;
1111
import org.elasticsearch.action.bulk.BulkResponse;
1212
import org.elasticsearch.action.support.WriteRequest;
13+
import org.elasticsearch.common.Nullable;
1314
import org.elasticsearch.xpack.core.ml.MachineLearningField;
1415
import org.elasticsearch.xpack.core.ml.annotations.Annotation;
1516
import org.elasticsearch.xpack.core.ml.job.process.autodetect.output.FlushAcknowledgement;
@@ -28,6 +29,7 @@
2829
import org.elasticsearch.xpack.ml.job.process.autodetect.AutodetectProcess;
2930
import org.elasticsearch.xpack.ml.job.results.AutodetectResult;
3031

32+
import java.time.Duration;
3133
import java.util.Iterator;
3234
import java.util.List;
3335
import java.util.Objects;
@@ -52,6 +54,7 @@ public class JobSnapshotUpgraderResultProcessor {
5254
private final JobResultsPersister persister;
5355
private final AutodetectProcess process;
5456
private final JobResultsPersister.Builder bulkResultsPersister;
57+
private final FlushListener flushListener;
5558
private volatile boolean processKilled;
5659
private volatile boolean failed;
5760

@@ -64,6 +67,7 @@ public JobSnapshotUpgraderResultProcessor(String jobId,
6467
this.persister = Objects.requireNonNull(persister);
6568
this.process = Objects.requireNonNull(autodetectProcess);
6669
this.bulkResultsPersister = persister.bulkPersisterBuilder(jobId).shouldRetry(this::isAlive);
70+
this.flushListener = new FlushListener();
6771
}
6872

6973
public void process() {
@@ -204,10 +208,34 @@ void processResult(AutodetectResult result) {
204208
}
205209
FlushAcknowledgement flushAcknowledgement = result.getFlushAcknowledgement();
206210
if (flushAcknowledgement != null) {
207-
logUnexpectedResult(FlushAcknowledgement.TYPE.getPreferredName());
211+
LOGGER.debug(
212+
() -> new ParameterizedMessage(
213+
"[{}] [{}] Flush acknowledgement parsed from output for ID {}",
214+
jobId,
215+
snapshotId,
216+
flushAcknowledgement.getId()
217+
)
218+
);
219+
flushListener.acknowledgeFlush(flushAcknowledgement, null);
208220
}
209221
}
210222

223+
/**
224+
* Blocks until a flush is acknowledged or the timeout expires, whichever happens first.
225+
*
226+
* @param flushId the id of the flush request to wait for
227+
* @param timeout the timeout
228+
* @return The {@link FlushAcknowledgement} if the flush has completed or the parsing finished; {@code null} if the timeout expired
229+
*/
230+
@Nullable
231+
public FlushAcknowledgement waitForFlushAcknowledgement(String flushId, Duration timeout) throws Exception {
232+
return failed ? null : flushListener.waitForFlush(flushId, timeout);
233+
}
234+
235+
public void clearAwaitingFlush(String flushId) {
236+
flushListener.clear(flushId);
237+
}
238+
211239
public void awaitCompletion() throws TimeoutException {
212240
try {
213241
// Although the results won't take 30 minutes to finish, the pipe won't be closed
@@ -230,7 +258,6 @@ public void awaitCompletion() throws TimeoutException {
230258
}
231259
}
232260

233-
234261
/**
235262
* If failed then there was an error parsing the results that cannot be recovered from
236263
*

0 commit comments

Comments
 (0)