Skip to content

Commit d07b11b

Browse files
[7.x][ML] Perform test inference on java (#58877) (#59298)
Since we are able to load the inference model and perform inference in Java, we no longer need to rely on the analytics process to perform test inference on the docs that were not used for training. The benefit is that we no longer need to send the test docs to the C++ process and fit them in its memory. Backport of #58877 Co-authored-by: Dimitris Athanasiou <[email protected]> Co-authored-by: Benjamin Trent <[email protected]>
1 parent 86555ec commit d07b11b

File tree

47 files changed

+862
-186
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

47 files changed

+862
-186
lines changed

x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/Classification.java

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -370,6 +370,11 @@ public InferenceConfig inferenceConfig(FieldInfo fieldInfo) {
370370
.build();
371371
}
372372

373+
@Override
374+
public boolean supportsInference() {
375+
return true;
376+
}
377+
373378
public static String extractJobIdFromStateDoc(String stateDocId) {
374379
int suffixIndex = stateDocId.lastIndexOf(STATE_DOC_ID_SUFFIX);
375380
return suffixIndex <= 0 ? null : stateDocId.substring(0, suffixIndex);

x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/DataFrameAnalysis.java

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,11 @@ public interface DataFrameAnalysis extends ToXContentObject, NamedWriteable {
7878
@Nullable
7979
InferenceConfig inferenceConfig(FieldInfo fieldInfo);
8080

81+
/**
82+
* @return {@code true} if this analysis trains a model that can be used for inference
83+
*/
84+
boolean supportsInference();
85+
8186
/**
8287
* Summarizes information about the fields that is necessary for analysis to generate
8388
* the parameters needed for the process configuration.

x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/OutlierDetection.java

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -257,6 +257,11 @@ public InferenceConfig inferenceConfig(FieldInfo fieldInfo) {
257257
return null;
258258
}
259259

260+
@Override
261+
public boolean supportsInference() {
262+
return false;
263+
}
264+
260265
public enum Method {
261266
LOF, LDOF, DISTANCE_KTH_NN, DISTANCE_KNN;
262267

x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/dataframe/analyses/Regression.java

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -283,6 +283,11 @@ public InferenceConfig inferenceConfig(FieldInfo fieldInfo) {
283283
.build();
284284
}
285285

286+
@Override
287+
public boolean supportsInference() {
288+
return true;
289+
}
290+
286291
public static String extractJobIdFromStateDoc(String stateDocId) {
287292
int suffixIndex = stateDocId.lastIndexOf(STATE_DOC_ID_SUFFIX);
288293
return suffixIndex <= 0 ? null : stateDocId.substring(0, suffixIndex);

x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/ClassificationInferenceResults.java

Lines changed: 14 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,8 @@
1717

1818
import java.io.IOException;
1919
import java.util.Collections;
20+
import java.util.Map;
21+
import java.util.LinkedHashMap;
2022
import java.util.List;
2123
import java.util.Objects;
2224
import java.util.stream.Collectors;
@@ -137,18 +139,20 @@ public Object predictedValue() {
137139
public void writeResult(IngestDocument document, String parentResultField) {
138140
ExceptionsHelper.requireNonNull(document, "document");
139141
ExceptionsHelper.requireNonNull(parentResultField, "parentResultField");
140-
document.setFieldValue(parentResultField + "." + this.resultsField,
141-
predictionFieldType.transformPredictedValue(value(), valueAsString()));
142-
if (topClasses.size() > 0) {
143-
document.setFieldValue(parentResultField + "." + topNumClassesField,
144-
topClasses.stream().map(TopClassEntry::asValueMap).collect(Collectors.toList()));
142+
document.setFieldValue(parentResultField, asMap());
143+
}
144+
145+
@Override
146+
public Map<String, Object> asMap() {
147+
Map<String, Object> map = new LinkedHashMap<>();
148+
map.put(resultsField, predictionFieldType.transformPredictedValue(value(), valueAsString()));
149+
if (topClasses.isEmpty() == false) {
150+
map.put(topNumClassesField, topClasses.stream().map(TopClassEntry::asValueMap).collect(Collectors.toList()));
145151
}
146-
if (getFeatureImportance().size() > 0) {
147-
document.setFieldValue(parentResultField + "." + FEATURE_IMPORTANCE, getFeatureImportance()
148-
.stream()
149-
.map(FeatureImportance::toMap)
150-
.collect(Collectors.toList()));
152+
if (getFeatureImportance().isEmpty() == false) {
153+
map.put(FEATURE_IMPORTANCE, getFeatureImportance().stream().map(FeatureImportance::toMap).collect(Collectors.toList()));
151154
}
155+
return map;
152156
}
153157

154158
@Override

x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/InferenceResults.java

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,9 +9,13 @@
99
import org.elasticsearch.common.xcontent.ToXContentFragment;
1010
import org.elasticsearch.ingest.IngestDocument;
1111

12+
import java.util.Map;
13+
1214
public interface InferenceResults extends NamedWriteable, ToXContentFragment {
1315

1416
void writeResult(IngestDocument document, String parentResultField);
1517

18+
Map<String, Object> asMap();
19+
1620
Object predictedValue();
1721
}

x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/RawInferenceResults.java

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111

1212
import java.io.IOException;
1313
import java.util.Arrays;
14+
import java.util.Map;
1415
import java.util.Objects;
1516

1617
public class RawInferenceResults implements InferenceResults {
@@ -57,6 +58,10 @@ public void writeResult(IngestDocument document, String parentResultField) {
5758
throw new UnsupportedOperationException("[raw] does not support writing inference results");
5859
}
5960

61+
@Override
62+
public Map<String, Object> asMap() {
63+
throw new UnsupportedOperationException("[raw] does not support map conversion");
64+
}
6065
@Override
6166
public Object predictedValue() {
6267
return null;

x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/RegressionInferenceResults.java

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,9 @@
1515

1616
import java.io.IOException;
1717
import java.util.Collections;
18+
import java.util.LinkedHashMap;
1819
import java.util.List;
20+
import java.util.Map;
1921
import java.util.Objects;
2022
import java.util.stream.Collectors;
2123

@@ -85,13 +87,17 @@ public Object predictedValue() {
8587
public void writeResult(IngestDocument document, String parentResultField) {
8688
ExceptionsHelper.requireNonNull(document, "document");
8789
ExceptionsHelper.requireNonNull(parentResultField, "parentResultField");
88-
document.setFieldValue(parentResultField + "." + this.resultsField, value());
89-
if (getFeatureImportance().size() > 0) {
90-
document.setFieldValue(parentResultField + ".feature_importance", getFeatureImportance()
91-
.stream()
92-
.map(FeatureImportance::toMap)
93-
.collect(Collectors.toList()));
90+
document.setFieldValue(parentResultField, asMap());
91+
}
92+
93+
@Override
94+
public Map<String, Object> asMap() {
95+
Map<String, Object> map = new LinkedHashMap<>();
96+
map.put(resultsField, value());
97+
if (getFeatureImportance().isEmpty() == false) {
98+
map.put(FEATURE_IMPORTANCE, getFeatureImportance().stream().map(FeatureImportance::toMap).collect(Collectors.toList()));
9499
}
100+
return map;
95101
}
96102

97103
@Override

x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/WarningInferenceResults.java

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,8 @@
1313
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
1414

1515
import java.io.IOException;
16+
import java.util.LinkedHashMap;
17+
import java.util.Map;
1618
import java.util.Objects;
1719

1820
public class WarningInferenceResults implements InferenceResults {
@@ -56,7 +58,14 @@ public int hashCode() {
5658
public void writeResult(IngestDocument document, String parentResultField) {
5759
ExceptionsHelper.requireNonNull(document, "document");
5860
ExceptionsHelper.requireNonNull(parentResultField, "resultField");
59-
document.setFieldValue(parentResultField + "." + NAME, warning);
61+
document.setFieldValue(parentResultField, asMap());
62+
}
63+
64+
@Override
65+
public Map<String, Object> asMap() {
66+
Map<String, Object> asMap = new LinkedHashMap<>();
67+
asMap.put(NAME, warning);
68+
return asMap;
6069
}
6170

6271
@Override

x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/ClassificationEvaluationIT.java

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -649,4 +649,9 @@ private static void indexDistinctAnimals(String indexName, int distinctAnimalCou
649649
fail("Failed to index data: " + bulkResponse.buildFailureMessage());
650650
}
651651
}
652+
653+
@Override
654+
boolean supportsInference() {
655+
return true;
656+
}
652657
}

x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/ClassificationIT.java

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -893,4 +893,9 @@ private String stateDocId() {
893893
private String expectedDestIndexAuditMessage() {
894894
return (analysisUsesExistingDestIndex ? "Using existing" : "Creating") + " destination index [" + destIndex + "]";
895895
}
896+
897+
@Override
898+
boolean supportsInference() {
899+
return true;
900+
}
896901
}

x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/ExplainDataFrameAnalyticsIT.java

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -123,7 +123,12 @@ public void testTrainingPercentageIsApplied() throws IOException {
123123

124124
explainResponse = explainDataFrame(config);
125125

126-
assertThat(explainResponse.getMemoryEstimation().getExpectedMemoryWithoutDisk(),
126+
assertThat(explainResponse.getMemoryEstimation().getExpectedMemoryWithoutDisk(),
127127
lessThanOrEqualTo(allDataUsedForTraining));
128128
}
129+
130+
@Override
131+
boolean supportsInference() {
132+
return false;
133+
}
129134
}

x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/MlNativeDataFrameAnalyticsIntegTestCase.java

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -219,6 +219,8 @@ protected void assertProgressComplete(String id) {
219219
progress.stream().allMatch(phaseProgress -> phaseProgress.getProgressPercent() == 100), is(true));
220220
}
221221

222+
abstract boolean supportsInference();
223+
222224
private List<PhaseProgress> getProgress(String id) {
223225
GetDataFrameAnalyticsStatsAction.Response.Stats stats = getAnalyticsStats(id);
224226
assertThat(stats.getId(), equalTo(id));
@@ -227,7 +229,12 @@ private List<PhaseProgress> getProgress(String id) {
227229
assertThat(progress.size(), greaterThanOrEqualTo(4));
228230
assertThat(progress.get(0).getPhase(), equalTo("reindexing"));
229231
assertThat(progress.get(1).getPhase(), equalTo("loading_data"));
230-
assertThat(progress.get(progress.size() - 1).getPhase(), equalTo("writing_results"));
232+
if (supportsInference()) {
233+
assertThat(progress.get(progress.size() - 2).getPhase(), equalTo("writing_results"));
234+
assertThat(progress.get(progress.size() - 1).getPhase(), equalTo("inference"));
235+
} else {
236+
assertThat(progress.get(progress.size() - 1).getPhase(), equalTo("writing_results"));
237+
}
231238
return progress;
232239
}
233240

x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/OutlierDetectionWithMissingFieldsIT.java

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -111,4 +111,9 @@ public void testMissingFields() throws Exception {
111111
assertProgressComplete(id);
112112
assertThat(searchStoredProgress(id).getHits().getTotalHits().value, equalTo(1L));
113113
}
114+
115+
@Override
116+
boolean supportsInference() {
117+
return false;
118+
}
114119
}

x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/RegressionEvaluationIT.java

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -155,4 +155,9 @@ private static void indexHousesData(String indexName) {
155155
fail("Failed to index data: " + bulkResponse.buildFailureMessage());
156156
}
157157
}
158+
159+
@Override
160+
boolean supportsInference() {
161+
return true;
162+
}
158163
}

x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/RegressionIT.java

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -540,4 +540,9 @@ private static Map<String, Object> getMlResultsObjectFromDestDoc(Map<String, Obj
540540
protected String stateDocId() {
541541
return jobId + "_regression_state#1";
542542
}
543+
544+
@Override
545+
boolean supportsInference() {
546+
return true;
547+
}
543548
}

x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/RunDataFrameAnalyticsIT.java

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -745,4 +745,9 @@ public void testOutlierDetectionWithCustomParams() throws Exception {
745745
"Started writing results",
746746
"Finished analysis");
747747
}
748+
749+
@Override
750+
boolean supportsInference() {
751+
return false;
752+
}
748753
}

x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MachineLearning.java

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -693,8 +693,13 @@ public Collection<Object> createComponents(Client client, ClusterService cluster
693693
this.modelLoadingService.set(modelLoadingService);
694694

695695
// Data frame analytics components
696-
AnalyticsProcessManager analyticsProcessManager = new AnalyticsProcessManager(client, threadPool, analyticsProcessFactory,
697-
dataFrameAnalyticsAuditor, trainedModelProvider, resultsPersisterService);
696+
AnalyticsProcessManager analyticsProcessManager = new AnalyticsProcessManager(client,
697+
threadPool,
698+
analyticsProcessFactory,
699+
dataFrameAnalyticsAuditor,
700+
trainedModelProvider,
701+
modelLoadingService,
702+
resultsPersisterService);
698703
MemoryUsageEstimationProcessManager memoryEstimationProcessManager =
699704
new MemoryUsageEstimationProcessManager(
700705
threadPool.generic(), threadPool.executor(MachineLearning.JOB_COMMS_THREAD_POOL_NAME), memoryEstimationProcessFactory);

x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportGetDataFrameAnalyticsStatsAction.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -195,7 +195,7 @@ private void searchStats(DataFrameAnalyticsConfig config, ActionListener<Stats>
195195
logger.debug("[{}] Gathering stats for stopped task", config.getId());
196196

197197
RetrievedStatsHolder retrievedStatsHolder = new RetrievedStatsHolder(
198-
ProgressTracker.fromZeroes(config.getAnalysis().getProgressPhases()).report());
198+
ProgressTracker.fromZeroes(config.getAnalysis().getProgressPhases(), config.getAnalysis().supportsInference()).report());
199199

200200
MultiSearchRequest multiSearchRequest = new MultiSearchRequest();
201201
multiSearchRequest.add(buildStoredProgressSearch(config.getId()));

x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/DataFrameAnalyticsManager.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -84,7 +84,8 @@ public void execute(DataFrameAnalyticsTask task, DataFrameAnalyticsState current
8484
// At this point we have the config at hand and we can reset the progress tracker
8585
// to use the analyses phases. We preserve reindexing progress as if reindexing was
8686
// finished it will not be reset.
87-
task.getStatsHolder().resetProgressTrackerPreservingReindexingProgress(config.getAnalysis().getProgressPhases());
87+
task.getStatsHolder().resetProgressTrackerPreservingReindexingProgress(config.getAnalysis().getProgressPhases(),
88+
config.getAnalysis().supportsInference());
8889

8990
switch(currentState) {
9091
// If we are STARTED, it means the job was started because the start API was called.

x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/DestinationIndex.java

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,11 @@ public final class DestinationIndex {
4949

5050
public static final String ID_COPY = "ml__id_copy";
5151

52+
/**
53+
* The field that indicates whether a doc was used for training or not
54+
*/
55+
public static final String IS_TRAINING = "is_training";
56+
5257
// Metadata fields
5358
static final String CREATION_DATE_MILLIS = "creation_date_in_millis";
5459
static final String VERSION = "version";

0 commit comments

Comments
 (0)