Skip to content

[ML] retry bulk indexing of state docs #50149

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -559,11 +559,17 @@ public Collection<Object> createComponents(Client client, ClusterService cluster
environment,
settings,
nativeController,
client,
clusterService);
clusterService,
resultsPersisterService,
anomalyDetectionAuditor);
normalizerProcessFactory = new NativeNormalizerProcessFactory(environment, nativeController, clusterService);
analyticsProcessFactory = new NativeAnalyticsProcessFactory(environment, client, nativeController, clusterService,
xContentRegistry);
analyticsProcessFactory = new NativeAnalyticsProcessFactory(
environment,
nativeController,
clusterService,
xContentRegistry,
resultsPersisterService,
dataFrameAnalyticsAuditor);
memoryEstimationProcessFactory =
new NativeMemoryUsageEstimationProcessFactory(environment, nativeController, clusterService);
mlController = nativeController;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@

import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.elasticsearch.client.Client;
import org.elasticsearch.cluster.service.ClusterService;
import org.elasticsearch.common.Nullable;
import org.elasticsearch.common.bytes.BytesReference;
Expand All @@ -20,10 +19,12 @@
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
import org.elasticsearch.xpack.ml.MachineLearning;
import org.elasticsearch.xpack.ml.dataframe.process.results.AnalyticsResult;
import org.elasticsearch.xpack.ml.notifications.DataFrameAnalyticsAuditor;
import org.elasticsearch.xpack.ml.process.IndexingStateProcessor;
import org.elasticsearch.xpack.ml.process.NativeController;
import org.elasticsearch.xpack.ml.process.ProcessPipes;
import org.elasticsearch.xpack.ml.utils.NamedPipeHelper;
import org.elasticsearch.xpack.ml.utils.persistence.ResultsPersisterService;

import java.io.IOException;
import java.nio.file.Path;
Expand All @@ -40,18 +41,24 @@ public class NativeAnalyticsProcessFactory implements AnalyticsProcessFactory<An

private static final NamedPipeHelper NAMED_PIPE_HELPER = new NamedPipeHelper();

private final Client client;
private final Environment env;
private final NativeController nativeController;
private final NamedXContentRegistry namedXContentRegistry;
private final ResultsPersisterService resultsPersisterService;
private final DataFrameAnalyticsAuditor auditor;
private volatile Duration processConnectTimeout;

public NativeAnalyticsProcessFactory(Environment env, Client client, NativeController nativeController, ClusterService clusterService,
NamedXContentRegistry namedXContentRegistry) {
public NativeAnalyticsProcessFactory(Environment env,
NativeController nativeController,
ClusterService clusterService,
NamedXContentRegistry namedXContentRegistry,
ResultsPersisterService resultsPersisterService,
DataFrameAnalyticsAuditor auditor) {
this.env = Objects.requireNonNull(env);
this.client = Objects.requireNonNull(client);
this.nativeController = Objects.requireNonNull(nativeController);
this.namedXContentRegistry = Objects.requireNonNull(namedXContentRegistry);
this.auditor = auditor;
this.resultsPersisterService = resultsPersisterService;
setProcessConnectTimeout(MachineLearning.PROCESS_CONNECT_TIMEOUT.get(env.settings()));
clusterService.getClusterSettings().addSettingsUpdateConsumer(MachineLearning.PROCESS_CONNECT_TIMEOUT,
this::setProcessConnectTimeout);
Expand Down Expand Up @@ -96,7 +103,7 @@ public NativeAnalyticsProcess createAnalyticsProcess(DataFrameAnalyticsConfig co
private void startProcess(DataFrameAnalyticsConfig config, ExecutorService executorService, ProcessPipes processPipes,
NativeAnalyticsProcess process) {
if (config.getAnalysis().persistsState()) {
IndexingStateProcessor stateProcessor = new IndexingStateProcessor(client, config.getId());
IndexingStateProcessor stateProcessor = new IndexingStateProcessor(config.getId(), resultsPersisterService, auditor);
process.start(executorService, stateProcessor, processPipes.getPersistStream().get());
} else {
process.start(executorService);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@

import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.elasticsearch.client.Client;
import org.elasticsearch.cluster.service.ClusterService;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.common.unit.TimeValue;
Expand All @@ -20,11 +19,13 @@
import org.elasticsearch.xpack.ml.MachineLearning;
import org.elasticsearch.xpack.ml.job.process.autodetect.params.AutodetectParams;
import org.elasticsearch.xpack.ml.job.results.AutodetectResult;
import org.elasticsearch.xpack.ml.notifications.AnomalyDetectionAuditor;
import org.elasticsearch.xpack.ml.process.IndexingStateProcessor;
import org.elasticsearch.xpack.ml.process.NativeController;
import org.elasticsearch.xpack.ml.process.ProcessPipes;
import org.elasticsearch.xpack.ml.process.ProcessResultsParser;
import org.elasticsearch.xpack.ml.utils.NamedPipeHelper;
import org.elasticsearch.xpack.ml.utils.persistence.ResultsPersisterService;

import java.io.IOException;
import java.nio.file.Path;
Expand All @@ -40,20 +41,26 @@ public class NativeAutodetectProcessFactory implements AutodetectProcessFactory
private static final Logger LOGGER = LogManager.getLogger(NativeAutodetectProcessFactory.class);
private static final NamedPipeHelper NAMED_PIPE_HELPER = new NamedPipeHelper();

private final Client client;
private final Environment env;
private final Settings settings;
private final NativeController nativeController;
private final ClusterService clusterService;
private final ResultsPersisterService resultsPersisterService;
private final AnomalyDetectionAuditor auditor;
private volatile Duration processConnectTimeout;

public NativeAutodetectProcessFactory(Environment env, Settings settings, NativeController nativeController, Client client,
ClusterService clusterService) {
public NativeAutodetectProcessFactory(Environment env,
Settings settings,
NativeController nativeController,
ClusterService clusterService,
ResultsPersisterService resultsPersisterService,
AnomalyDetectionAuditor auditor) {
this.env = Objects.requireNonNull(env);
this.settings = Objects.requireNonNull(settings);
this.nativeController = Objects.requireNonNull(nativeController);
this.client = client;
this.clusterService = clusterService;
this.resultsPersisterService = resultsPersisterService;
this.auditor = auditor;
setProcessConnectTimeout(MachineLearning.PROCESS_CONNECT_TIMEOUT.get(settings));
clusterService.getClusterSettings().addSettingsUpdateConsumer(MachineLearning.PROCESS_CONNECT_TIMEOUT,
this::setProcessConnectTimeout);
Expand All @@ -78,7 +85,7 @@ public AutodetectProcess createAutodetectProcess(Job job,
// The extra 1 is the control field
int numberOfFields = job.allInputFields().size() + (includeTokensField ? 1 : 0) + 1;

IndexingStateProcessor stateProcessor = new IndexingStateProcessor(client, job.getId());
IndexingStateProcessor stateProcessor = new IndexingStateProcessor(job.getId(), resultsPersisterService, auditor);
ProcessResultsParser<AutodetectResult> resultsParser = new ProcessResultsParser<>(AutodetectResult.PARSER,
NamedXContentRegistry.EMPTY);
NativeAutodetectProcess autodetect = new NativeAutodetectProcess(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,21 +7,22 @@

import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.apache.logging.log4j.message.ParameterizedMessage;
import org.elasticsearch.action.bulk.BulkRequest;
import org.elasticsearch.client.Client;
import org.elasticsearch.common.bytes.BytesArray;
import org.elasticsearch.common.bytes.BytesReference;
import org.elasticsearch.common.bytes.CompositeBytesReference;
import org.elasticsearch.common.util.concurrent.ThreadContext;
import org.elasticsearch.common.xcontent.XContentType;
import org.elasticsearch.xpack.core.common.notifications.AbstractAuditMessage;
import org.elasticsearch.xpack.core.common.notifications.AbstractAuditor;
import org.elasticsearch.xpack.core.ml.job.persistence.AnomalyDetectorsIndex;
import org.elasticsearch.xpack.ml.utils.persistence.ResultsPersisterService;

import java.io.IOException;
import java.io.InputStream;
import java.util.ArrayList;
import java.util.List;

import static org.elasticsearch.xpack.core.ClientHelper.ML_ORIGIN;

/**
* Reads state documents from a stream, splits them and persists them to an index via a bulk request
Expand All @@ -32,12 +33,16 @@ public class IndexingStateProcessor implements StateProcessor {

private static final int READ_BUF_SIZE = 8192;

private final Client client;
private final String jobId;
private final AbstractAuditor<? extends AbstractAuditMessage> auditor;
private final ResultsPersisterService resultsPersisterService;

public IndexingStateProcessor(Client client, String jobId) {
this.client = client;
public IndexingStateProcessor(String jobId,
ResultsPersisterService resultsPersisterService,
AbstractAuditor<? extends AbstractAuditMessage> auditor) {
this.jobId = jobId;
this.resultsPersisterService = resultsPersisterService;
this.auditor = auditor;
}

@Override
Expand Down Expand Up @@ -98,8 +103,15 @@ void persist(BytesReference bytes) throws IOException {
bulkRequest.add(bytes, AnomalyDetectorsIndex.jobStateIndexWriteAlias(), XContentType.JSON);
if (bulkRequest.numberOfActions() > 0) {
LOGGER.trace("[{}] Persisting job state document", jobId);
try (ThreadContext.StoredContext ignore = client.threadPool().getThreadContext().stashWithOrigin(ML_ORIGIN)) {
client.bulk(bulkRequest).actionGet();
try {
resultsPersisterService.bulkIndexWithRetry(bulkRequest,
jobId,
() -> true,
(msg) -> auditor.warning(jobId, "Bulk indexing of state failed " + msg));
} catch (Exception ex) {
String msg = "failed indexing updated state docs";
LOGGER.error(() -> new ParameterizedMessage("[{}] {}", jobId, msg), ex);
auditor.error(jobId, msg + " error: " + ex.getMessage());
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
*/
package org.elasticsearch.xpack.ml.job.process.autodetect;

import org.elasticsearch.client.Client;
import org.elasticsearch.cluster.service.ClusterService;
import org.elasticsearch.common.settings.ClusterSettings;
import org.elasticsearch.common.settings.Settings;
Expand All @@ -16,8 +15,10 @@
import org.elasticsearch.xpack.core.ml.job.config.Job;
import org.elasticsearch.xpack.ml.MachineLearning;
import org.elasticsearch.xpack.ml.job.process.autodetect.params.AutodetectParams;
import org.elasticsearch.xpack.ml.notifications.AnomalyDetectionAuditor;
import org.elasticsearch.xpack.ml.process.NativeController;
import org.elasticsearch.xpack.ml.process.ProcessPipes;
import org.elasticsearch.xpack.ml.utils.persistence.ResultsPersisterService;

import java.io.IOException;
import java.time.Duration;
Expand All @@ -41,7 +42,8 @@ public void testSetProcessConnectTimeout() throws IOException {
.build();
Environment env = TestEnvironment.newEnvironment(settings);
NativeController nativeController = mock(NativeController.class);
Client client = mock(Client.class);
ResultsPersisterService resultsPersisterService = mock(ResultsPersisterService.class);
AnomalyDetectionAuditor anomalyDetectionAuditor = mock(AnomalyDetectionAuditor.class);
ClusterSettings clusterSettings = new ClusterSettings(settings,
Set.of(MachineLearning.PROCESS_CONNECT_TIMEOUT, AutodetectBuilder.MAX_ANOMALY_RECORDS_SETTING_DYNAMIC));
ClusterService clusterService = mock(ClusterService.class);
Expand All @@ -51,8 +53,13 @@ public void testSetProcessConnectTimeout() throws IOException {
AutodetectParams autodetectParams = mock(AutodetectParams.class);
ProcessPipes processPipes = mock(ProcessPipes.class);

NativeAutodetectProcessFactory nativeAutodetectProcessFactory =
new NativeAutodetectProcessFactory(env, settings, nativeController, client, clusterService);
NativeAutodetectProcessFactory nativeAutodetectProcessFactory = new NativeAutodetectProcessFactory(
env,
settings,
nativeController,
clusterService,
resultsPersisterService,
anomalyDetectionAuditor);
nativeAutodetectProcessFactory.setProcessConnectTimeout(TimeValue.timeValueSeconds(timeoutSeconds));
nativeAutodetectProcessFactory.createNativeProcess(job, autodetectParams, processPipes, Collections.emptyList());

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,16 +6,16 @@
package org.elasticsearch.xpack.ml.process;

import com.carrotsearch.randomizedtesting.annotations.Timeout;
import org.elasticsearch.action.ActionFuture;
import org.elasticsearch.action.bulk.BulkRequest;
import org.elasticsearch.action.bulk.BulkResponse;
import org.elasticsearch.client.Client;
import org.elasticsearch.common.bytes.BytesReference;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.common.util.concurrent.ThreadContext;
import org.elasticsearch.mock.orig.Mockito;
import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.xpack.ml.notifications.AnomalyDetectionAuditor;
import org.elasticsearch.xpack.ml.utils.persistence.ResultsPersisterService;
import org.junit.After;
import org.junit.Before;
import org.mockito.ArgumentCaptor;
Expand Down Expand Up @@ -54,24 +54,22 @@ public class IndexingStateProcessorTests extends ESTestCase {
private static final int NUM_LARGE_DOCS = 2;
private static final int LARGE_DOC_SIZE = 1000000;

private Client client;
private IndexingStateProcessor stateProcessor;
private ResultsPersisterService resultsPersisterService;

@Before
public void initialize() throws IOException {
client = mock(Client.class);
@SuppressWarnings("unchecked")
ActionFuture<BulkResponse> bulkResponseFuture = mock(ActionFuture.class);
stateProcessor = spy(new IndexingStateProcessor(client, JOB_ID));
when(client.bulk(any(BulkRequest.class))).thenReturn(bulkResponseFuture);
public void initialize() {
resultsPersisterService = mock(ResultsPersisterService.class);
AnomalyDetectionAuditor auditor = mock(AnomalyDetectionAuditor.class);
stateProcessor = spy(new IndexingStateProcessor(JOB_ID, resultsPersisterService, auditor));
when(resultsPersisterService.bulkIndexWithRetry(any(BulkRequest.class), any(), any(), any())).thenReturn(mock(BulkResponse.class));
ThreadPool threadPool = mock(ThreadPool.class);
when(client.threadPool()).thenReturn(threadPool);
when(threadPool.getThreadContext()).thenReturn(new ThreadContext(Settings.EMPTY));
}

@After
public void verifyNoMoreClientInteractions() {
Mockito.verifyNoMoreInteractions(client);
Mockito.verifyNoMoreInteractions(resultsPersisterService);
}

public void testStateRead() throws IOException {
Expand All @@ -85,8 +83,7 @@ public void testStateRead() throws IOException {
assertEquals(threeStates[0], capturedBytes.get(0).utf8ToString());
assertEquals(threeStates[1], capturedBytes.get(1).utf8ToString());
assertEquals(threeStates[2], capturedBytes.get(2).utf8ToString());
verify(client, times(3)).bulk(any(BulkRequest.class));
verify(client, times(3)).threadPool();
verify(resultsPersisterService, times(3)).bulkIndexWithRetry(any(BulkRequest.class), any(), any(), any());
}

public void testStateReadGivenConsecutiveZeroBytes() throws IOException {
Expand All @@ -96,7 +93,7 @@ public void testStateReadGivenConsecutiveZeroBytes() throws IOException {
stateProcessor.process(stream);

verify(stateProcessor, never()).persist(any());
Mockito.verifyNoMoreInteractions(client);
Mockito.verifyNoMoreInteractions(resultsPersisterService);
}

public void testStateReadGivenConsecutiveSpacesFollowedByZeroByte() throws IOException {
Expand All @@ -106,7 +103,7 @@ public void testStateReadGivenConsecutiveSpacesFollowedByZeroByte() throws IOExc
stateProcessor.process(stream);

verify(stateProcessor, times(1)).persist(any());
Mockito.verifyNoMoreInteractions(client);
Mockito.verifyNoMoreInteractions(resultsPersisterService);
}

/**
Expand All @@ -128,7 +125,6 @@ public void testLargeStateRead() throws Exception {
ByteArrayInputStream stream = new ByteArrayInputStream(builder.toString().getBytes(StandardCharsets.UTF_8));
stateProcessor.process(stream);
verify(stateProcessor, times(NUM_LARGE_DOCS)).persist(any());
verify(client, times(NUM_LARGE_DOCS)).bulk(any(BulkRequest.class));
verify(client, times(NUM_LARGE_DOCS)).threadPool();
verify(resultsPersisterService, times(NUM_LARGE_DOCS)).bulkIndexWithRetry(any(BulkRequest.class), any(), any(), any());
}
}