Skip to content

[ML] Prepare parsing phase_progress from DFA process #55580

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TargetType;
import org.elasticsearch.xpack.core.ml.job.messages.Messages;
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
import org.elasticsearch.xpack.core.ml.utils.PhaseProgress;
import org.elasticsearch.xpack.core.security.user.XPackUser;
import org.elasticsearch.xpack.ml.dataframe.process.results.AnalyticsResult;
import org.elasticsearch.xpack.ml.dataframe.process.results.RowResults;
Expand Down Expand Up @@ -164,11 +165,20 @@ private void processResult(AnalyticsResult result, DataFrameRowsJoiner resultsJo
if (rowResults != null) {
resultsJoiner.processRowResults(rowResults);
}
PhaseProgress phaseProgress = result.getPhaseProgress();
if (phaseProgress != null) {
LOGGER.debug("[{}] progress for phase [{}] updated to [{}]", analytics.getId(), phaseProgress.getPhase(),
phaseProgress.getProgressPercent());
statsHolder.getProgressTracker().analyzingPercent.set(phaseProgress.getProgressPercent());
}

// TODO remove once the process writes out phase_progress
Integer progressPercent = result.getProgressPercent();
if (progressPercent != null) {
LOGGER.debug("[{}] Analyzing progress updated to [{}]", analytics.getId(), progressPercent);
statsHolder.getProgressTracker().analyzingPercent.set(progressPercent);
}

TrainedModelDefinition.Builder inferenceModelBuilder = result.getInferenceModelBuilder();
if (inferenceModelBuilder != null) {
createAndIndexInferenceModel(inferenceModelBuilder);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,12 @@
import org.elasticsearch.common.xcontent.ToXContent;
import org.elasticsearch.common.xcontent.ToXContentObject;
import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.xpack.core.ml.dataframe.stats.common.MemoryUsage;
import org.elasticsearch.xpack.core.ml.dataframe.stats.classification.ClassificationStats;
import org.elasticsearch.xpack.core.ml.dataframe.stats.common.MemoryUsage;
import org.elasticsearch.xpack.core.ml.dataframe.stats.outlierdetection.OutlierDetectionStats;
import org.elasticsearch.xpack.core.ml.dataframe.stats.regression.RegressionStats;
import org.elasticsearch.xpack.core.ml.inference.TrainedModelDefinition;
import org.elasticsearch.xpack.core.ml.utils.PhaseProgress;

import java.io.IOException;
import java.util.Collections;
Expand All @@ -28,6 +29,7 @@ public class AnalyticsResult implements ToXContentObject {

public static final ParseField TYPE = new ParseField("analytics_result");

private static final ParseField PHASE_PROGRESS = new ParseField("phase_progress");
private static final ParseField PROGRESS_PERCENT = new ParseField("progress_percent");
private static final ParseField INFERENCE_MODEL = new ParseField("inference_model");
private static final ParseField ANALYTICS_MEMORY_USAGE = new ParseField("analytics_memory_usage");
Expand All @@ -38,16 +40,18 @@ public class AnalyticsResult implements ToXContentObject {
public static final ConstructingObjectParser<AnalyticsResult, Void> PARSER = new ConstructingObjectParser<>(TYPE.getPreferredName(),
a -> new AnalyticsResult(
(RowResults) a[0],
(Integer) a[1],
(TrainedModelDefinition.Builder) a[2],
(MemoryUsage) a[3],
(OutlierDetectionStats) a[4],
(ClassificationStats) a[5],
(RegressionStats) a[6]
(PhaseProgress) a[1],
(Integer) a[2],
(TrainedModelDefinition.Builder) a[3],
(MemoryUsage) a[4],
(OutlierDetectionStats) a[5],
(ClassificationStats) a[6],
(RegressionStats) a[7]
));

static {
PARSER.declareObject(optionalConstructorArg(), RowResults.PARSER, RowResults.TYPE);
PARSER.declareObject(optionalConstructorArg(), PhaseProgress.PARSER, PHASE_PROGRESS);
PARSER.declareInt(optionalConstructorArg(), PROGRESS_PERCENT);
// TODO change back to STRICT_PARSER once native side is aligned
PARSER.declareObject(optionalConstructorArg(), TrainedModelDefinition.LENIENT_PARSER, INFERENCE_MODEL);
Expand All @@ -58,7 +62,11 @@ public class AnalyticsResult implements ToXContentObject {
}

private final RowResults rowResults;
private final PhaseProgress phaseProgress;

// TODO remove once the process writes out phase_progress
private final Integer progressPercent;

private final TrainedModelDefinition.Builder inferenceModelBuilder;
private final TrainedModelDefinition inferenceModel;
private final MemoryUsage memoryUsage;
Expand All @@ -67,13 +75,15 @@ public class AnalyticsResult implements ToXContentObject {
private final RegressionStats regressionStats;

public AnalyticsResult(@Nullable RowResults rowResults,
@Nullable PhaseProgress phaseProgress,
@Nullable Integer progressPercent,
@Nullable TrainedModelDefinition.Builder inferenceModelBuilder,
@Nullable MemoryUsage memoryUsage,
@Nullable OutlierDetectionStats outlierDetectionStats,
@Nullable ClassificationStats classificationStats,
@Nullable RegressionStats regressionStats) {
this.rowResults = rowResults;
this.phaseProgress = phaseProgress;
this.progressPercent = progressPercent;
this.inferenceModelBuilder = inferenceModelBuilder;
this.inferenceModel = inferenceModelBuilder == null ? null : inferenceModelBuilder.build();
Expand All @@ -87,6 +97,10 @@ public RowResults getRowResults() {
return rowResults;
}

/**
 * Returns the phase progress carried by this result.
 *
 * @return the {@code phaseProgress} value supplied to the constructor; may be {@code null}
 *         because the constructor argument is declared {@code @Nullable}
 */
public PhaseProgress getPhaseProgress() {
return phaseProgress;
}

/**
 * Returns the progress percentage carried by this result.
 *
 * <p>The backing field is marked with a TODO for removal in favour of {@code phase_progress}.
 *
 * @return the {@code progressPercent} value supplied to the constructor; may be {@code null}
 *         because the constructor argument is declared {@code @Nullable}
 */
public Integer getProgressPercent() {
return progressPercent;
}
Expand Down Expand Up @@ -117,6 +131,9 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
if (rowResults != null) {
builder.field(RowResults.TYPE.getPreferredName(), rowResults);
}
if (phaseProgress != null) {
builder.field(PHASE_PROGRESS.getPreferredName(), phaseProgress);
}
if (progressPercent != null) {
builder.field(PROGRESS_PERCENT.getPreferredName(), progressPercent);
}
Expand Down Expand Up @@ -152,6 +169,7 @@ public boolean equals(Object other) {

AnalyticsResult that = (AnalyticsResult) other;
return Objects.equals(rowResults, that.rowResults)
&& Objects.equals(phaseProgress, that.phaseProgress)
&& Objects.equals(progressPercent, that.progressPercent)
&& Objects.equals(inferenceModel, that.inferenceModel)
&& Objects.equals(memoryUsage, that.memoryUsage)
Expand All @@ -162,7 +180,7 @@ public boolean equals(Object other) {

/**
 * Computes a hash code from the fields of this result via {@link Objects#hash(Object...)}.
 */
@Override
public int hashCode() {
// NOTE(review): two consecutive return statements appear here. This looks like the removed and
// added sides of a diff hunk rendered together -- confirm against the merged file. Only the
// second statement includes phaseProgress, which equals() also compares.
return Objects.hash(rowResults, progressPercent, inferenceModel, memoryUsage, outlierDetectionStats, classificationStats,
regressionStats);
return Objects.hash(rowResults, phaseProgress, progressPercent, inferenceModel, memoryUsage, outlierDetectionStats,
classificationStats, regressionStats);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ public class AnalyticsProcessManagerTests extends ESTestCase {
private static final String CONFIG_ID = "config-id";
private static final int NUM_ROWS = 100;
private static final int NUM_COLS = 4;
private static final AnalyticsResult PROCESS_RESULT = new AnalyticsResult(null, null, null, null, null, null, null);
private static final AnalyticsResult PROCESS_RESULT = new AnalyticsResult(null, null, null, null, null, null, null, null);

private Client client;
private DataFrameAnalyticsAuditor auditor;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -105,8 +105,8 @@ public void testProcess_GivenNoResults() {
public void testProcess_GivenEmptyResults() {
givenDataFrameRows(2);
givenProcessResults(Arrays.asList(
new AnalyticsResult(null, 50, null, null, null, null, null),
new AnalyticsResult(null, 100, null, null, null, null, null)));
new AnalyticsResult(null, null,50, null, null, null, null, null),
new AnalyticsResult(null, null, 100, null, null, null, null, null)));
AnalyticsResultProcessor resultProcessor = createResultProcessor();

resultProcessor.process(process);
Expand All @@ -121,8 +121,8 @@ public void testProcess_GivenRowResults() {
givenDataFrameRows(2);
RowResults rowResults1 = mock(RowResults.class);
RowResults rowResults2 = mock(RowResults.class);
givenProcessResults(Arrays.asList(new AnalyticsResult(rowResults1, 50, null, null, null, null, null),
new AnalyticsResult(rowResults2, 100, null, null, null, null, null)));
givenProcessResults(Arrays.asList(new AnalyticsResult(rowResults1, null,50, null, null, null, null, null),
new AnalyticsResult(rowResults2, null, 100, null, null, null, null, null)));
AnalyticsResultProcessor resultProcessor = createResultProcessor();

resultProcessor.process(process);
Expand All @@ -139,8 +139,8 @@ public void testProcess_GivenDataFrameRowsJoinerFails() {
givenDataFrameRows(2);
RowResults rowResults1 = mock(RowResults.class);
RowResults rowResults2 = mock(RowResults.class);
givenProcessResults(Arrays.asList(new AnalyticsResult(rowResults1, 50, null, null, null, null, null),
new AnalyticsResult(rowResults2, 100, null, null, null, null, null)));
givenProcessResults(Arrays.asList(new AnalyticsResult(rowResults1, null,50, null, null, null, null, null),
new AnalyticsResult(rowResults2, null, 100, null, null, null, null, null)));

doThrow(new RuntimeException("some failure")).when(dataFrameRowsJoiner).processRowResults(any(RowResults.class));

Expand Down Expand Up @@ -174,7 +174,7 @@ public void testProcess_GivenInferenceModelIsStoredSuccessfully() {
extractedFieldList.add(new DocValueField("baz", Collections.emptySet()));
TargetType targetType = analyticsConfig.getAnalysis() instanceof Regression ? TargetType.REGRESSION : TargetType.CLASSIFICATION;
TrainedModelDefinition.Builder inferenceModel = TrainedModelDefinitionTests.createRandomBuilder(targetType);
givenProcessResults(Arrays.asList(new AnalyticsResult(null, null, inferenceModel, null, null, null, null)));
givenProcessResults(Arrays.asList(new AnalyticsResult(null, null, null, inferenceModel, null, null, null, null)));
AnalyticsResultProcessor resultProcessor = createResultProcessor(extractedFieldList);

resultProcessor.process(process);
Expand Down Expand Up @@ -238,7 +238,7 @@ public void testProcess_GivenInferenceModelFailedToStore() {

TargetType targetType = analyticsConfig.getAnalysis() instanceof Regression ? TargetType.REGRESSION : TargetType.CLASSIFICATION;
TrainedModelDefinition.Builder inferenceModel = TrainedModelDefinitionTests.createRandomBuilder(targetType);
givenProcessResults(Arrays.asList(new AnalyticsResult(null, null, inferenceModel, null, null, null, null)));
givenProcessResults(Arrays.asList(new AnalyticsResult(null, null, null, inferenceModel, null, null, null, null)));
AnalyticsResultProcessor resultProcessor = createResultProcessor();

resultProcessor.process(process);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import org.elasticsearch.xpack.core.ml.inference.MlInferenceNamedXContentProvider;
import org.elasticsearch.xpack.core.ml.inference.TrainedModelDefinition;
import org.elasticsearch.xpack.core.ml.inference.TrainedModelDefinitionTests;
import org.elasticsearch.xpack.core.ml.utils.PhaseProgress;
import org.elasticsearch.xpack.core.ml.utils.ToXContentParams;

import java.util.ArrayList;
Expand All @@ -41,6 +42,7 @@ protected NamedXContentRegistry xContentRegistry() {
@Override
protected AnalyticsResult createTestInstance() {
RowResults rowResults = null;
PhaseProgress phaseProgress = null;
Integer progressPercent = null;
TrainedModelDefinition.Builder inferenceModel = null;
MemoryUsage memoryUsage = null;
Expand All @@ -50,6 +52,9 @@ protected AnalyticsResult createTestInstance() {
if (randomBoolean()) {
rowResults = RowResultsTests.createRandom();
}
if (randomBoolean()) {
phaseProgress = new PhaseProgress(randomAlphaOfLength(10), randomIntBetween(0, 100));
}
if (randomBoolean()) {
progressPercent = randomIntBetween(0, 100);
}
Expand All @@ -68,8 +73,8 @@ protected AnalyticsResult createTestInstance() {
if (randomBoolean()) {
regressionStats = RegressionStatsTests.createRandom();
}
return new AnalyticsResult(rowResults, progressPercent, inferenceModel, memoryUsage, outlierDetectionStats, classificationStats,
regressionStats);
return new AnalyticsResult(rowResults, phaseProgress, progressPercent, inferenceModel, memoryUsage, outlierDetectionStats,
classificationStats, regressionStats);
}

@Override
Expand Down