diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsProcessManager.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsProcessManager.java index 9f938e87a72e3..47d9628939fa0 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsProcessManager.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsProcessManager.java @@ -434,7 +434,6 @@ synchronized void stop() { if (inferenceRunner.get() != null) { inferenceRunner.get().cancel(); } - statsPersister.cancel(); if (process.get() != null) { try { process.get().kill(); diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsResultProcessor.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsResultProcessor.java index 32af30a93da4a..41dcea15577e8 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsResultProcessor.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsResultProcessor.java @@ -56,6 +56,7 @@ public class AnalyticsResultProcessor { private final ChunkedTrainedModelPersister chunkedTrainedModelPersister; private volatile String failure; private volatile boolean isCancelled; + private long processedRows; private volatile String latestModelId; @@ -92,31 +93,17 @@ public void awaitForCompletion() { public void cancel() { dataFrameRowsJoiner.cancel(); - statsPersister.cancel(); isCancelled = true; } public void process(AnalyticsProcess process) { long totalRows = process.getConfig().rows(); - long processedRows = 0; // TODO When java 9 features can be used, we will not need the local variable here try (DataFrameRowsJoiner resultsJoiner = dataFrameRowsJoiner) { Iterator iterator = process.readAnalyticsResults(); while (iterator.hasNext()) { - if (isCancelled) { - break; - } - AnalyticsResult result = iterator.next(); - processResult(result, resultsJoiner); - if (result.getRowResults() != null) { - if (processedRows == 0) { - LOGGER.info("[{}] Started writing results", analytics.getId()); - auditor.info(analytics.getId(), Messages.getMessage(Messages.DATA_FRAME_ANALYTICS_AUDIT_STARTED_WRITING_RESULTS)); - } - processedRows++; - updateResultsProgress(processedRows >= totalRows ? 100 : (int) (processedRows * 100.0 / totalRows)); - } + processResult(iterator.next(), resultsJoiner, totalRows); } } catch (Exception e) { if (isCancelled) { @@ -141,10 +128,10 @@ private void completeResultsProgress() { statsHolder.getProgressTracker().updateWritingResultsProgress(100); } - private void processResult(AnalyticsResult result, DataFrameRowsJoiner resultsJoiner) { + private void processResult(AnalyticsResult result, DataFrameRowsJoiner resultsJoiner, long totalRows) { RowResults rowResults = result.getRowResults(); - if (rowResults != null) { - resultsJoiner.processRowResults(rowResults); + if (rowResults != null && isCancelled == false) { + processRowResult(resultsJoiner, totalRows, rowResults); } PhaseProgress phaseProgress = result.getPhaseProgress(); if (phaseProgress != null) { @@ -157,7 +144,7 @@ private void processResult(AnalyticsResult result, DataFrameRowsJoiner resultsJo latestModelId = chunkedTrainedModelPersister.createAndIndexInferenceModelMetadata(modelSize); } TrainedModelDefinitionChunk trainedModelDefinitionChunk = result.getTrainedModelDefinitionChunk(); - if (trainedModelDefinitionChunk != null) { + if (trainedModelDefinitionChunk != null && isCancelled == false) { chunkedTrainedModelPersister.createAndIndexInferenceModelDoc(trainedModelDefinitionChunk); } MemoryUsage memoryUsage = result.getMemoryUsage(); @@ -181,6 +168,16 @@ private void processResult(AnalyticsResult result, DataFrameRowsJoiner resultsJo } } + private void processRowResult(DataFrameRowsJoiner rowsJoiner, long totalRows, RowResults rowResults) { + rowsJoiner.processRowResults(rowResults); + if (processedRows == 0) { + LOGGER.info("[{}] Started writing results", analytics.getId()); + auditor.info(analytics.getId(), Messages.getMessage(Messages.DATA_FRAME_ANALYTICS_AUDIT_STARTED_WRITING_RESULTS)); + } + processedRows++; + updateResultsProgress(processedRows >= totalRows ? 100 : (int) (processedRows * 100.0 / totalRows)); + } + private void setAndReportFailure(Exception e) { LOGGER.error(new ParameterizedMessage("[{}] Error processing results; ", analytics.getId()), e); failure = "error processing results; " + e.getMessage(); diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/results/AnalyticsResult.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/results/AnalyticsResult.java index c7e6f9a1c4377..0020c2df8bc88 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/results/AnalyticsResult.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/process/results/AnalyticsResult.java @@ -66,14 +66,14 @@ public class AnalyticsResult implements ToXContentObject { private final ModelSizeInfo modelSizeInfo; private final TrainedModelDefinitionChunk trainedModelDefinitionChunk; - public AnalyticsResult(@Nullable RowResults rowResults, - @Nullable PhaseProgress phaseProgress, - @Nullable MemoryUsage memoryUsage, - @Nullable OutlierDetectionStats outlierDetectionStats, - @Nullable ClassificationStats classificationStats, - @Nullable RegressionStats regressionStats, - @Nullable ModelSizeInfo modelSizeInfo, - @Nullable TrainedModelDefinitionChunk trainedModelDefinitionChunk) { + private AnalyticsResult(@Nullable RowResults rowResults, + @Nullable PhaseProgress phaseProgress, + @Nullable MemoryUsage memoryUsage, + @Nullable OutlierDetectionStats outlierDetectionStats, + @Nullable ClassificationStats classificationStats, + @Nullable RegressionStats regressionStats, + @Nullable ModelSizeInfo modelSizeInfo, + @Nullable TrainedModelDefinitionChunk trainedModelDefinitionChunk) { this.rowResults = rowResults; this.phaseProgress = phaseProgress; this.memoryUsage = memoryUsage; @@ -172,4 +172,75 @@ public int hashCode() { return Objects.hash(rowResults, phaseProgress, memoryUsage, outlierDetectionStats, classificationStats, regressionStats, modelSizeInfo, trainedModelDefinitionChunk); } + + public static Builder builder() { + return new Builder(); + } + + public static class Builder { + + private RowResults rowResults; + private PhaseProgress phaseProgress; + private MemoryUsage memoryUsage; + private OutlierDetectionStats outlierDetectionStats; + private ClassificationStats classificationStats; + private RegressionStats regressionStats; + private ModelSizeInfo modelSizeInfo; + private TrainedModelDefinitionChunk trainedModelDefinitionChunk; + + private Builder() {} + + public Builder setRowResults(RowResults rowResults) { + this.rowResults = rowResults; + return this; + } + + public Builder setPhaseProgress(PhaseProgress phaseProgress) { + this.phaseProgress = phaseProgress; + return this; + } + + public Builder setMemoryUsage(MemoryUsage memoryUsage) { + this.memoryUsage = memoryUsage; + return this; + } + + public Builder setOutlierDetectionStats(OutlierDetectionStats outlierDetectionStats) { + this.outlierDetectionStats = outlierDetectionStats; + return this; + } + + public Builder setClassificationStats(ClassificationStats classificationStats) { + this.classificationStats = classificationStats; + return this; + } + + public Builder setRegressionStats(RegressionStats regressionStats) { + this.regressionStats = regressionStats; + return this; + } + + public Builder setModelSizeInfo(ModelSizeInfo modelSizeInfo) { + this.modelSizeInfo = modelSizeInfo; + return this; + } + + public Builder setTrainedModelDefinitionChunk(TrainedModelDefinitionChunk trainedModelDefinitionChunk) { + this.trainedModelDefinitionChunk = trainedModelDefinitionChunk; + return this; + } + + public AnalyticsResult build() { + return new AnalyticsResult( + rowResults, + phaseProgress, + memoryUsage, + outlierDetectionStats, + classificationStats, + regressionStats, + modelSizeInfo, + trainedModelDefinitionChunk + ); + } + } } diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/stats/StatsPersister.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/stats/StatsPersister.java index e553ac1810729..4d8179e9fcc94 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/stats/StatsPersister.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/stats/StatsPersister.java @@ -29,7 +29,6 @@ public class StatsPersister { private final String jobId; private final ResultsPersisterService resultsPersisterService; private final DataFrameAnalyticsAuditor auditor; - private volatile boolean isCancelled; public StatsPersister(String jobId, ResultsPersisterService resultsPersisterService, DataFrameAnalyticsAuditor auditor) { this.jobId = Objects.requireNonNull(jobId); @@ -38,10 +37,6 @@ public StatsPersister(String jobId, ResultsPersisterService resultsPersisterServ } public void persistWithRetry(ToXContentObject result, Function docIdSupplier) { - if (isCancelled) { - return; - } - try { resultsPersisterService.indexWithRetry(jobId, MlStatsIndex.writeAlias(), @@ -49,7 +44,7 @@ public void persistWithRetry(ToXContentObject result, Function d new ToXContent.MapParams(Collections.singletonMap(ToXContentParams.FOR_INTERNAL_STORAGE, "true")), WriteRequest.RefreshPolicy.NONE, docIdSupplier.apply(jobId), - () -> isCancelled == false, + () -> true, errorMsg -> auditor.error(jobId, "failed to persist result with id [" + docIdSupplier.apply(jobId) + "]; " + errorMsg) ); @@ -59,8 +54,4 @@ public void persistWithRetry(ToXContentObject result, Function d LOGGER.error(() -> new ParameterizedMessage("[{}] Failed indexing stats result", jobId), e); } } - - public void cancel() { - isCancelled = true; - } } diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsProcessManagerTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsProcessManagerTests.java index 0803246f18ca1..8d7e9d0742189 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsProcessManagerTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsProcessManagerTests.java @@ -60,7 +60,7 @@ public class AnalyticsProcessManagerTests extends ESTestCase { private static final String CONFIG_ID = "config-id"; private static final int NUM_ROWS = 100; private static final int NUM_COLS = 4; - private static final AnalyticsResult PROCESS_RESULT = new AnalyticsResult(null, null, null, null, null, null, null, null); + private static final AnalyticsResult PROCESS_RESULT = AnalyticsResult.builder().build(); private Client client; private DataFrameAnalyticsAuditor auditor; diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsResultProcessorTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsResultProcessorTests.java index f2b33c30755f6..1e404360ae738 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsResultProcessorTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/process/AnalyticsResultProcessorTests.java @@ -12,8 +12,17 @@ import org.elasticsearch.xpack.core.ml.dataframe.DataFrameAnalyticsSource; import org.elasticsearch.xpack.core.ml.dataframe.analyses.DataFrameAnalysis; import org.elasticsearch.xpack.core.ml.dataframe.analyses.Regression; +import org.elasticsearch.xpack.core.ml.dataframe.stats.classification.ClassificationStats; +import org.elasticsearch.xpack.core.ml.dataframe.stats.classification.ClassificationStatsTests; +import org.elasticsearch.xpack.core.ml.dataframe.stats.common.MemoryUsage; +import org.elasticsearch.xpack.core.ml.dataframe.stats.outlierdetection.OutlierDetectionStats; +import org.elasticsearch.xpack.core.ml.dataframe.stats.outlierdetection.OutlierDetectionStatsTests; +import org.elasticsearch.xpack.core.ml.dataframe.stats.regression.RegressionStats; +import org.elasticsearch.xpack.core.ml.dataframe.stats.regression.RegressionStatsTests; +import org.elasticsearch.xpack.core.ml.utils.PhaseProgress; import org.elasticsearch.xpack.ml.dataframe.process.results.AnalyticsResult; import org.elasticsearch.xpack.ml.dataframe.process.results.RowResults; +import org.elasticsearch.xpack.ml.dataframe.process.results.TrainedModelDefinitionChunk; import org.elasticsearch.xpack.ml.dataframe.stats.ProgressTracker; import org.elasticsearch.xpack.ml.dataframe.stats.StatsHolder; import org.elasticsearch.xpack.ml.dataframe.stats.StatsPersister; @@ -26,12 +35,15 @@ import org.mockito.InOrder; import org.mockito.Mockito; +import java.time.Instant; import java.util.Arrays; import java.util.Collections; import java.util.List; +import java.util.Optional; import static org.hamcrest.Matchers.containsString; import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.is; import static org.mockito.Matchers.any; import static org.mockito.Matchers.eq; import static org.mockito.Mockito.doThrow; @@ -85,8 +97,8 @@ public void testProcess_GivenNoResults() { public void testProcess_GivenEmptyResults() { givenDataFrameRows(2); givenProcessResults(Arrays.asList( - new AnalyticsResult(null, null, null,null, null, null, null, null), - new AnalyticsResult(null, null, null, null, null, null, null, null))); + AnalyticsResult.builder().build(), + AnalyticsResult.builder().build())); AnalyticsResultProcessor resultProcessor = createResultProcessor(); resultProcessor.process(process); @@ -101,8 +113,9 @@ public void testProcess_GivenRowResults() { givenDataFrameRows(2); RowResults rowResults1 = mock(RowResults.class); RowResults rowResults2 = mock(RowResults.class); - givenProcessResults(Arrays.asList(new AnalyticsResult(rowResults1, null, null,null, null, null, null, null), - new AnalyticsResult(rowResults2, null, null, null, null, null, null, null))); + givenProcessResults(Arrays.asList( + AnalyticsResult.builder().setRowResults(rowResults1).build(), + AnalyticsResult.builder().setRowResults(rowResults2).build())); AnalyticsResultProcessor resultProcessor = createResultProcessor(); resultProcessor.process(process); @@ -119,8 +132,9 @@ public void testProcess_GivenDataFrameRowsJoinerFails() { givenDataFrameRows(2); RowResults rowResults1 = mock(RowResults.class); RowResults rowResults2 = mock(RowResults.class); - givenProcessResults(Arrays.asList(new AnalyticsResult(rowResults1, null, null,null, null, null, null, null), - new AnalyticsResult(rowResults2, null, null, null, null, null, null, null))); + givenProcessResults(Arrays.asList( + AnalyticsResult.builder().setRowResults(rowResults1).build(), + AnalyticsResult.builder().setRowResults(rowResults2).build())); doThrow(new RuntimeException("some failure")).when(dataFrameRowsJoiner).processRowResults(any(RowResults.class)); @@ -138,6 +152,146 @@ public void testProcess_GivenDataFrameRowsJoinerFails() { assertThat(statsHolder.getProgressTracker().getWritingResultsProgressPercent(), equalTo(0)); } + public void testCancel_GivenRowResults() { + givenDataFrameRows(2); + RowResults rowResults1 = mock(RowResults.class); + RowResults rowResults2 = mock(RowResults.class); + givenProcessResults(Arrays.asList( + AnalyticsResult.builder().setRowResults(rowResults1).build(), + AnalyticsResult.builder().setRowResults(rowResults2).build())); + AnalyticsResultProcessor resultProcessor = createResultProcessor(); + + resultProcessor.cancel(); + + resultProcessor.process(process); + resultProcessor.awaitForCompletion(); + + verify(dataFrameRowsJoiner).cancel(); + verify(dataFrameRowsJoiner).close(); + Mockito.verifyNoMoreInteractions(dataFrameRowsJoiner, trainedModelProvider); + assertThat(statsHolder.getProgressTracker().getWritingResultsProgressPercent(), equalTo(0)); + } + + public void testCancel_GivenModelChunk() { + givenDataFrameRows(2); + TrainedModelDefinitionChunk modelChunk = mock(TrainedModelDefinitionChunk.class); + givenProcessResults(Arrays.asList(AnalyticsResult.builder().setTrainedModelDefinitionChunk(modelChunk).build())); + AnalyticsResultProcessor resultProcessor = createResultProcessor(); + + resultProcessor.cancel(); + + resultProcessor.process(process); + resultProcessor.awaitForCompletion(); + + verify(dataFrameRowsJoiner).cancel(); + verify(dataFrameRowsJoiner).close(); + Mockito.verifyNoMoreInteractions(dataFrameRowsJoiner, trainedModelProvider); + assertThat(statsHolder.getProgressTracker().getWritingResultsProgressPercent(), equalTo(0)); + } + + public void testCancel_GivenPhaseProgress() { + givenDataFrameRows(2); + PhaseProgress phaseProgress = new PhaseProgress("analyzing", 18); + givenProcessResults(Arrays.asList(AnalyticsResult.builder().setPhaseProgress(phaseProgress).build())); + AnalyticsResultProcessor resultProcessor = createResultProcessor(); + + resultProcessor.cancel(); + + resultProcessor.process(process); + resultProcessor.awaitForCompletion(); + + verify(dataFrameRowsJoiner).cancel(); + verify(dataFrameRowsJoiner).close(); + Mockito.verifyNoMoreInteractions(dataFrameRowsJoiner, trainedModelProvider); + assertThat(statsHolder.getProgressTracker().getWritingResultsProgressPercent(), equalTo(0)); + + Optional testPhaseProgress = statsHolder.getProgressTracker().report().stream() + .filter(p -> p.getPhase().equals(phaseProgress.getPhase())) + .findAny(); + assertThat(testPhaseProgress.isPresent(), is(true)); + assertThat(testPhaseProgress.get().getProgressPercent(), equalTo(18)); + } + + public void testCancel_GivenMemoryUsage() { + givenDataFrameRows(2); + MemoryUsage memoryUsage = new MemoryUsage(analyticsConfig.getId(), Instant.now(), 1000L, MemoryUsage.Status.HARD_LIMIT, null); + givenProcessResults(Arrays.asList(AnalyticsResult.builder().setMemoryUsage(memoryUsage).build())); + AnalyticsResultProcessor resultProcessor = createResultProcessor(); + + resultProcessor.cancel(); + + resultProcessor.process(process); + resultProcessor.awaitForCompletion(); + + verify(dataFrameRowsJoiner).cancel(); + verify(dataFrameRowsJoiner).close(); + Mockito.verifyNoMoreInteractions(dataFrameRowsJoiner, trainedModelProvider); + assertThat(statsHolder.getProgressTracker().getWritingResultsProgressPercent(), equalTo(0)); + + assertThat(statsHolder.getMemoryUsage(), equalTo(memoryUsage)); + verify(statsPersister).persistWithRetry(eq(memoryUsage), any()); + } + + public void testCancel_GivenOutlierDetectionStats() { + givenDataFrameRows(2); + OutlierDetectionStats outlierDetectionStats = OutlierDetectionStatsTests.createRandom(); + givenProcessResults(Arrays.asList(AnalyticsResult.builder().setOutlierDetectionStats(outlierDetectionStats).build())); + AnalyticsResultProcessor resultProcessor = createResultProcessor(); + + resultProcessor.cancel(); + + resultProcessor.process(process); + resultProcessor.awaitForCompletion(); + + verify(dataFrameRowsJoiner).cancel(); + verify(dataFrameRowsJoiner).close(); + Mockito.verifyNoMoreInteractions(dataFrameRowsJoiner, trainedModelProvider); + assertThat(statsHolder.getProgressTracker().getWritingResultsProgressPercent(), equalTo(0)); + + assertThat(statsHolder.getAnalysisStats(), equalTo(outlierDetectionStats)); + verify(statsPersister).persistWithRetry(eq(outlierDetectionStats), any()); + } + + public void testCancel_GivenClassificationStats() { + givenDataFrameRows(2); + ClassificationStats classificationStats = ClassificationStatsTests.createRandom(); + givenProcessResults(Arrays.asList(AnalyticsResult.builder().setClassificationStats(classificationStats).build())); + AnalyticsResultProcessor resultProcessor = createResultProcessor(); + + resultProcessor.cancel(); + + resultProcessor.process(process); + resultProcessor.awaitForCompletion(); + + verify(dataFrameRowsJoiner).cancel(); + verify(dataFrameRowsJoiner).close(); + Mockito.verifyNoMoreInteractions(dataFrameRowsJoiner, trainedModelProvider); + assertThat(statsHolder.getProgressTracker().getWritingResultsProgressPercent(), equalTo(0)); + + assertThat(statsHolder.getAnalysisStats(), equalTo(classificationStats)); + verify(statsPersister).persistWithRetry(eq(classificationStats), any()); + } + + public void testCancel_GivenRegressionStats() { + givenDataFrameRows(2); + RegressionStats regressionStats = RegressionStatsTests.createRandom(); + givenProcessResults(Arrays.asList(AnalyticsResult.builder().setRegressionStats(regressionStats).build())); + AnalyticsResultProcessor resultProcessor = createResultProcessor(); + + resultProcessor.cancel(); + + resultProcessor.process(process); + resultProcessor.awaitForCompletion(); + + verify(dataFrameRowsJoiner).cancel(); + verify(dataFrameRowsJoiner).close(); + Mockito.verifyNoMoreInteractions(dataFrameRowsJoiner, trainedModelProvider); + assertThat(statsHolder.getProgressTracker().getWritingResultsProgressPercent(), equalTo(0)); + + assertThat(statsHolder.getAnalysisStats(), equalTo(regressionStats)); + verify(statsPersister).persistWithRetry(eq(regressionStats), any()); + } + private void givenProcessResults(List results) { when(process.readAnalyticsResults()).thenReturn(results.iterator()); } diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/process/results/AnalyticsResultTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/process/results/AnalyticsResultTests.java index 66dce7d70a58e..35e562bca2459 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/process/results/AnalyticsResultTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/dataframe/process/results/AnalyticsResultTests.java @@ -11,19 +11,14 @@ import org.elasticsearch.common.xcontent.XContentParser; import org.elasticsearch.search.SearchModule; import org.elasticsearch.test.AbstractXContentTestCase; -import org.elasticsearch.xpack.core.ml.dataframe.stats.common.MemoryUsage; -import org.elasticsearch.xpack.core.ml.dataframe.stats.common.MemoryUsageTests; -import org.elasticsearch.xpack.core.ml.dataframe.stats.classification.ClassificationStats; import org.elasticsearch.xpack.core.ml.dataframe.stats.classification.ClassificationStatsTests; -import org.elasticsearch.xpack.core.ml.dataframe.stats.outlierdetection.OutlierDetectionStats; +import org.elasticsearch.xpack.core.ml.dataframe.stats.common.MemoryUsageTests; import org.elasticsearch.xpack.core.ml.dataframe.stats.outlierdetection.OutlierDetectionStatsTests; -import org.elasticsearch.xpack.core.ml.dataframe.stats.regression.RegressionStats; import org.elasticsearch.xpack.core.ml.dataframe.stats.regression.RegressionStatsTests; import org.elasticsearch.xpack.core.ml.inference.MlInferenceNamedXContentProvider; import org.elasticsearch.xpack.core.ml.utils.PhaseProgress; import org.elasticsearch.xpack.core.ml.utils.ToXContentParams; import org.elasticsearch.xpack.ml.inference.modelsize.MlModelSizeNamedXContentProvider; -import org.elasticsearch.xpack.ml.inference.modelsize.ModelSizeInfo; import org.elasticsearch.xpack.ml.inference.modelsize.ModelSizeInfoTests; import java.util.ArrayList; @@ -42,41 +37,34 @@ protected NamedXContentRegistry xContentRegistry() { } protected AnalyticsResult createTestInstance() { - RowResults rowResults = null; - PhaseProgress phaseProgress = null; - MemoryUsage memoryUsage = null; - OutlierDetectionStats outlierDetectionStats = null; - ClassificationStats classificationStats = null; - RegressionStats regressionStats = null; - ModelSizeInfo modelSizeInfo = null; - TrainedModelDefinitionChunk trainedModelDefinitionChunk = null; + AnalyticsResult.Builder builder = AnalyticsResult.builder(); + if (randomBoolean()) { - rowResults = RowResultsTests.createRandom(); + builder.setRowResults(RowResultsTests.createRandom()); } if (randomBoolean()) { - phaseProgress = new PhaseProgress(randomAlphaOfLength(10), randomIntBetween(0, 100)); + builder.setPhaseProgress(new PhaseProgress(randomAlphaOfLength(10), randomIntBetween(0, 100))); } if (randomBoolean()) { - memoryUsage = MemoryUsageTests.createRandom(); + builder.setMemoryUsage(MemoryUsageTests.createRandom()); } if (randomBoolean()) { - outlierDetectionStats = OutlierDetectionStatsTests.createRandom(); + builder.setOutlierDetectionStats(OutlierDetectionStatsTests.createRandom()); } if (randomBoolean()) { - classificationStats = ClassificationStatsTests.createRandom(); + builder.setClassificationStats(ClassificationStatsTests.createRandom()); } if (randomBoolean()) { - regressionStats = RegressionStatsTests.createRandom(); + builder.setRegressionStats(RegressionStatsTests.createRandom()); } if (randomBoolean()) { - modelSizeInfo = ModelSizeInfoTests.createRandom(); + builder.setModelSizeInfo(ModelSizeInfoTests.createRandom()); } if (randomBoolean()) { String def = randomAlphaOfLengthBetween(100, 1000); - trainedModelDefinitionChunk = new TrainedModelDefinitionChunk(def, randomIntBetween(0, 10), randomBoolean()); + builder.setTrainedModelDefinitionChunk(new TrainedModelDefinitionChunk(def, randomIntBetween(0, 10), randomBoolean())); } - return new AnalyticsResult(rowResults, phaseProgress, memoryUsage, outlierDetectionStats, - classificationStats, regressionStats, modelSizeInfo, trainedModelDefinitionChunk); + return builder.build(); } @Override