Skip to content

Commit 7667004

Browse files
authored
[ML] Add a model memory estimation endpoint for anomaly detection (#54129)
A new endpoint for estimating anomaly detection job model memory requirements: POST _ml/anomaly_detectors/estimate_model_memory Backport of #53507
1 parent 7c0123d commit 7667004

File tree

14 files changed

+669
-31
lines changed

14 files changed

+669
-31
lines changed

client/rest-high-level/src/main/java/org/elasticsearch/client/MLRequestConverters.java

+12
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@
4040
import org.elasticsearch.client.ml.DeleteJobRequest;
4141
import org.elasticsearch.client.ml.DeleteModelSnapshotRequest;
4242
import org.elasticsearch.client.ml.DeleteTrainedModelRequest;
43+
import org.elasticsearch.client.ml.EstimateModelMemoryRequest;
4344
import org.elasticsearch.client.ml.EvaluateDataFrameRequest;
4445
import org.elasticsearch.client.ml.ExplainDataFrameAnalyticsRequest;
4546
import org.elasticsearch.client.ml.FindFileStructureRequest;
@@ -593,6 +594,17 @@ static Request deleteCalendarEvent(DeleteCalendarEventRequest deleteCalendarEven
593594
return new Request(HttpDelete.METHOD_NAME, endpoint);
594595
}
595596

597+
static Request estimateModelMemory(EstimateModelMemoryRequest estimateModelMemoryRequest) throws IOException {
598+
String endpoint = new EndpointBuilder()
599+
.addPathPartAsIs("_ml")
600+
.addPathPartAsIs("anomaly_detectors")
601+
.addPathPartAsIs("_estimate_model_memory")
602+
.build();
603+
Request request = new Request(HttpPost.METHOD_NAME, endpoint);
604+
request.setEntity(createEntity(estimateModelMemoryRequest, REQUEST_BODY_CONTENT_TYPE));
605+
return request;
606+
}
607+
596608
static Request putDataFrameAnalytics(PutDataFrameAnalyticsRequest putRequest) throws IOException {
597609
String endpoint = new EndpointBuilder()
598610
.addPathPartAsIs("_ml", "data_frame", "analytics")

client/rest-high-level/src/main/java/org/elasticsearch/client/MachineLearningClient.java

+44
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,8 @@
2323
import org.elasticsearch.client.ml.CloseJobRequest;
2424
import org.elasticsearch.client.ml.CloseJobResponse;
2525
import org.elasticsearch.client.ml.DeleteTrainedModelRequest;
26+
import org.elasticsearch.client.ml.EstimateModelMemoryRequest;
27+
import org.elasticsearch.client.ml.EstimateModelMemoryResponse;
2628
import org.elasticsearch.client.ml.ExplainDataFrameAnalyticsRequest;
2729
import org.elasticsearch.client.ml.ExplainDataFrameAnalyticsResponse;
2830
import org.elasticsearch.client.ml.DeleteCalendarEventRequest;
@@ -1951,6 +1953,48 @@ public Cancellable setUpgradeModeAsync(SetUpgradeModeRequest request, RequestOpt
19511953
Collections.emptySet());
19521954
}
19531955

1956+
/**
1957+
* Estimate the model memory an analysis config is likely to need given supplied field cardinalities
1958+
* <p>
1959+
* For additional info
1960+
* see <a href="https://www.elastic.co/guide/en/elasticsearch/reference/current/ml-estimate-model-memory.html">Estimate Model Memory</a>
1961+
*
1962+
* @param request The {@link EstimateModelMemoryRequest}
1963+
* @param options Additional request options (e.g. headers), use {@link RequestOptions#DEFAULT} if nothing needs to be customized
1964+
* @return {@link EstimateModelMemoryResponse} response object
1965+
*/
1966+
public EstimateModelMemoryResponse estimateModelMemory(EstimateModelMemoryRequest request,
1967+
RequestOptions options) throws IOException {
1968+
return restHighLevelClient.performRequestAndParseEntity(request,
1969+
MLRequestConverters::estimateModelMemory,
1970+
options,
1971+
EstimateModelMemoryResponse::fromXContent,
1972+
Collections.emptySet());
1973+
}
1974+
1975+
/**
1976+
* Estimate the model memory an analysis config is likely to need given supplied field cardinalities and notifies listener upon
1977+
* completion
1978+
* <p>
1979+
* For additional info
1980+
* see <a href="https://www.elastic.co/guide/en/elasticsearch/reference/current/ml-estimate-model-memory.html">Estimate Model Memory</a>
1981+
*
1982+
* @param request The {@link EstimateModelMemoryRequest}
1983+
* @param options Additional request options (e.g. headers), use {@link RequestOptions#DEFAULT} if nothing needs to be customized
1984+
* @param listener Listener to be notified upon request completion
1985+
* @return cancellable that may be used to cancel the request
1986+
*/
1987+
public Cancellable estimateModelMemoryAsync(EstimateModelMemoryRequest request,
1988+
RequestOptions options,
1989+
ActionListener<EstimateModelMemoryResponse> listener) {
1990+
return restHighLevelClient.performRequestAsyncAndParseEntity(request,
1991+
MLRequestConverters::estimateModelMemory,
1992+
options,
1993+
EstimateModelMemoryResponse::fromXContent,
1994+
listener,
1995+
Collections.emptySet());
1996+
}
1997+
19541998
/**
19551999
* Creates a new Data Frame Analytics config
19562000
* <p>
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,110 @@
1+
/*
2+
* Licensed to Elasticsearch under one or more contributor
3+
* license agreements. See the NOTICE file distributed with
4+
* this work for additional information regarding copyright
5+
* ownership. Elasticsearch licenses this file to you under
6+
* the Apache License, Version 2.0 (the "License"); you may
7+
* not use this file except in compliance with the License.
8+
* You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing,
13+
* software distributed under the License is distributed on an
14+
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15+
* KIND, either express or implied. See the License for the
16+
* specific language governing permissions and limitations
17+
* under the License.
18+
*/
19+
20+
package org.elasticsearch.client.ml;
21+
22+
import org.elasticsearch.client.Validatable;
23+
import org.elasticsearch.client.ValidationException;
24+
import org.elasticsearch.client.ml.job.config.AnalysisConfig;
25+
import org.elasticsearch.common.xcontent.ToXContentObject;
26+
import org.elasticsearch.common.xcontent.XContentBuilder;
27+
28+
import java.io.IOException;
29+
import java.util.Collections;
30+
import java.util.Map;
31+
import java.util.Objects;
32+
import java.util.Optional;
33+
34+
/**
35+
* Request to estimate the model memory an analysis config is likely to need given supplied field cardinalities.
36+
*/
37+
public class EstimateModelMemoryRequest implements Validatable, ToXContentObject {
38+
39+
public static final String ANALYSIS_CONFIG = "analysis_config";
40+
public static final String OVERALL_CARDINALITY = "overall_cardinality";
41+
public static final String MAX_BUCKET_CARDINALITY = "max_bucket_cardinality";
42+
43+
private final AnalysisConfig analysisConfig;
44+
private Map<String, Long> overallCardinality = Collections.emptyMap();
45+
private Map<String, Long> maxBucketCardinality = Collections.emptyMap();
46+
47+
@Override
48+
public Optional<ValidationException> validate() {
49+
return Optional.empty();
50+
}
51+
52+
public EstimateModelMemoryRequest(AnalysisConfig analysisConfig) {
53+
this.analysisConfig = Objects.requireNonNull(analysisConfig);
54+
}
55+
56+
public AnalysisConfig getAnalysisConfig() {
57+
return analysisConfig;
58+
}
59+
60+
public Map<String, Long> getOverallCardinality() {
61+
return overallCardinality;
62+
}
63+
64+
public void setOverallCardinality(Map<String, Long> overallCardinality) {
65+
this.overallCardinality = Collections.unmodifiableMap(overallCardinality);
66+
}
67+
68+
public Map<String, Long> getMaxBucketCardinality() {
69+
return maxBucketCardinality;
70+
}
71+
72+
public void setMaxBucketCardinality(Map<String, Long> maxBucketCardinality) {
73+
this.maxBucketCardinality = Collections.unmodifiableMap(maxBucketCardinality);
74+
}
75+
76+
@Override
77+
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
78+
builder.startObject();
79+
builder.field(ANALYSIS_CONFIG, analysisConfig);
80+
if (overallCardinality.isEmpty() == false) {
81+
builder.field(OVERALL_CARDINALITY, overallCardinality);
82+
}
83+
if (maxBucketCardinality.isEmpty() == false) {
84+
builder.field(MAX_BUCKET_CARDINALITY, maxBucketCardinality);
85+
}
86+
builder.endObject();
87+
return builder;
88+
}
89+
90+
@Override
91+
public int hashCode() {
92+
return Objects.hash(analysisConfig, overallCardinality, maxBucketCardinality);
93+
}
94+
95+
@Override
96+
public boolean equals(Object other) {
97+
if (this == other) {
98+
return true;
99+
}
100+
101+
if (other == null || getClass() != other.getClass()) {
102+
return false;
103+
}
104+
105+
EstimateModelMemoryRequest that = (EstimateModelMemoryRequest) other;
106+
return Objects.equals(analysisConfig, that.analysisConfig) &&
107+
Objects.equals(overallCardinality, that.overallCardinality) &&
108+
Objects.equals(maxBucketCardinality, that.maxBucketCardinality);
109+
}
110+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,80 @@
1+
/*
2+
* Licensed to Elasticsearch under one or more contributor
3+
* license agreements. See the NOTICE file distributed with
4+
* this work for additional information regarding copyright
5+
* ownership. Elasticsearch licenses this file to you under
6+
* the Apache License, Version 2.0 (the "License"); you may
7+
* not use this file except in compliance with the License.
8+
* You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing,
13+
* software distributed under the License is distributed on an
14+
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15+
* KIND, either express or implied. See the License for the
16+
* specific language governing permissions and limitations
17+
* under the License.
18+
*/
19+
20+
package org.elasticsearch.client.ml;
21+
22+
import org.elasticsearch.common.ParseField;
23+
import org.elasticsearch.common.unit.ByteSizeValue;
24+
import org.elasticsearch.common.xcontent.ConstructingObjectParser;
25+
import org.elasticsearch.common.xcontent.XContentParser;
26+
27+
import java.util.Objects;
28+
29+
import static org.elasticsearch.common.xcontent.ConstructingObjectParser.constructorArg;
30+
31+
public class EstimateModelMemoryResponse {
32+
33+
public static final ParseField MODEL_MEMORY_ESTIMATE = new ParseField("model_memory_estimate");
34+
35+
static final ConstructingObjectParser<EstimateModelMemoryResponse, Void> PARSER =
36+
new ConstructingObjectParser<>(
37+
"estimate_model_memory",
38+
true,
39+
args -> new EstimateModelMemoryResponse((String) args[0]));
40+
41+
static {
42+
PARSER.declareString(constructorArg(), MODEL_MEMORY_ESTIMATE);
43+
}
44+
45+
public static EstimateModelMemoryResponse fromXContent(final XContentParser parser) {
46+
return PARSER.apply(parser, null);
47+
}
48+
49+
private final ByteSizeValue modelMemoryEstimate;
50+
51+
public EstimateModelMemoryResponse(String modelMemoryEstimate) {
52+
this.modelMemoryEstimate = ByteSizeValue.parseBytesSizeValue(modelMemoryEstimate, MODEL_MEMORY_ESTIMATE.getPreferredName());
53+
}
54+
55+
/**
56+
* @return An estimate of the model memory the supplied analysis config is likely to need given the supplied field cardinalities.
57+
*/
58+
public ByteSizeValue getModelMemoryEstimate() {
59+
return modelMemoryEstimate;
60+
}
61+
62+
@Override
63+
public boolean equals(Object o) {
64+
65+
if (this == o) {
66+
return true;
67+
}
68+
if (o == null || getClass() != o.getClass()) {
69+
return false;
70+
}
71+
72+
EstimateModelMemoryResponse other = (EstimateModelMemoryResponse) o;
73+
return Objects.equals(this.modelMemoryEstimate, other.modelMemoryEstimate);
74+
}
75+
76+
@Override
77+
public int hashCode() {
78+
return Objects.hash(modelMemoryEstimate);
79+
}
80+
}

client/rest-high-level/src/test/java/org/elasticsearch/client/MLRequestConvertersTests.java

+21
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@
3636
import org.elasticsearch.client.ml.DeleteJobRequest;
3737
import org.elasticsearch.client.ml.DeleteModelSnapshotRequest;
3838
import org.elasticsearch.client.ml.DeleteTrainedModelRequest;
39+
import org.elasticsearch.client.ml.EstimateModelMemoryRequest;
3940
import org.elasticsearch.client.ml.EvaluateDataFrameRequest;
4041
import org.elasticsearch.client.ml.EvaluateDataFrameRequestTests;
4142
import org.elasticsearch.client.ml.ExplainDataFrameAnalyticsRequest;
@@ -107,6 +108,7 @@
107108
import org.elasticsearch.common.settings.Settings;
108109
import org.elasticsearch.common.unit.TimeValue;
109110
import org.elasticsearch.common.xcontent.NamedXContentRegistry;
111+
import org.elasticsearch.common.xcontent.ToXContent;
110112
import org.elasticsearch.common.xcontent.XContentBuilder;
111113
import org.elasticsearch.common.xcontent.XContentParser;
112114
import org.elasticsearch.common.xcontent.XContentType;
@@ -695,6 +697,25 @@ public void testDeleteCalendarEvent() {
695697
assertEquals("/_ml/calendars/" + calendarId + "/events/" + eventId, request.getEndpoint());
696698
}
697699

700+
public void testEstimateModelMemory() throws Exception {
701+
String byFieldName = randomAlphaOfLength(10);
702+
String influencerFieldName = randomAlphaOfLength(10);
703+
AnalysisConfig analysisConfig = AnalysisConfig.builder(
704+
Collections.singletonList(
705+
Detector.builder().setFunction("count").setByFieldName(byFieldName).build()
706+
)).setInfluencers(Collections.singletonList(influencerFieldName)).build();
707+
EstimateModelMemoryRequest estimateModelMemoryRequest = new EstimateModelMemoryRequest(analysisConfig);
708+
estimateModelMemoryRequest.setOverallCardinality(Collections.singletonMap(byFieldName, randomNonNegativeLong()));
709+
estimateModelMemoryRequest.setMaxBucketCardinality(Collections.singletonMap(influencerFieldName, randomNonNegativeLong()));
710+
Request request = MLRequestConverters.estimateModelMemory(estimateModelMemoryRequest);
711+
assertEquals(HttpPost.METHOD_NAME, request.getMethod());
712+
assertEquals("/_ml/anomaly_detectors/_estimate_model_memory", request.getEndpoint());
713+
714+
XContentBuilder builder = JsonXContent.contentBuilder();
715+
builder = estimateModelMemoryRequest.toXContent(builder, ToXContent.EMPTY_PARAMS);
716+
assertEquals(Strings.toString(builder), requestEntityToString(request));
717+
}
718+
698719
public void testPutDataFrameAnalytics() throws IOException {
699720
PutDataFrameAnalyticsRequest putRequest = new PutDataFrameAnalyticsRequest(randomDataFrameAnalyticsConfig());
700721
Request request = MLRequestConverters.putDataFrameAnalytics(putRequest);

client/rest-high-level/src/test/java/org/elasticsearch/client/MachineLearningIT.java

+23
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,8 @@
4646
import org.elasticsearch.client.ml.DeleteJobResponse;
4747
import org.elasticsearch.client.ml.DeleteModelSnapshotRequest;
4848
import org.elasticsearch.client.ml.DeleteTrainedModelRequest;
49+
import org.elasticsearch.client.ml.EstimateModelMemoryRequest;
50+
import org.elasticsearch.client.ml.EstimateModelMemoryResponse;
4951
import org.elasticsearch.client.ml.EvaluateDataFrameRequest;
5052
import org.elasticsearch.client.ml.EvaluateDataFrameResponse;
5153
import org.elasticsearch.client.ml.ExplainDataFrameAnalyticsRequest;
@@ -1274,6 +1276,27 @@ public void testDeleteCalendarEvent() throws IOException {
12741276
assertThat(remainingIds, not(hasItem(deletedEvent)));
12751277
}
12761278

1279+
public void testEstimateModelMemory() throws Exception {
1280+
MachineLearningClient machineLearningClient = highLevelClient().machineLearning();
1281+
1282+
String byFieldName = randomAlphaOfLength(10);
1283+
String influencerFieldName = randomAlphaOfLength(10);
1284+
AnalysisConfig analysisConfig = AnalysisConfig.builder(
1285+
Collections.singletonList(
1286+
Detector.builder().setFunction("count").setByFieldName(byFieldName).build()
1287+
)).setInfluencers(Collections.singletonList(influencerFieldName)).build();
1288+
EstimateModelMemoryRequest estimateModelMemoryRequest = new EstimateModelMemoryRequest(analysisConfig);
1289+
estimateModelMemoryRequest.setOverallCardinality(Collections.singletonMap(byFieldName, randomNonNegativeLong()));
1290+
estimateModelMemoryRequest.setMaxBucketCardinality(Collections.singletonMap(influencerFieldName, randomNonNegativeLong()));
1291+
1292+
EstimateModelMemoryResponse estimateModelMemoryResponse = execute(
1293+
estimateModelMemoryRequest,
1294+
machineLearningClient::estimateModelMemory, machineLearningClient::estimateModelMemoryAsync);
1295+
1296+
ByteSizeValue modelMemoryEstimate = estimateModelMemoryResponse.getModelMemoryEstimate();
1297+
assertThat(modelMemoryEstimate.getBytes(), greaterThanOrEqualTo(10000000L));
1298+
}
1299+
12771300
public void testPutDataFrameAnalyticsConfig_GivenOutlierDetectionAnalysis() throws Exception {
12781301
MachineLearningClient machineLearningClient = highLevelClient().machineLearning();
12791302
String configId = "test-put-df-analytics-outlier-detection";

0 commit comments

Comments
 (0)