Skip to content

Commit c6d977c

Browse files
authored
[ML][Inference] Adding model memory estimations (#48323)
* [ML][Inference] Adding model memory estimations * addressing PR comments
1 parent 18ea144 commit c6d977c

File tree

29 files changed

+734
-263
lines changed

29 files changed

+734
-263
lines changed

x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/TrainedModelConfig.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -174,7 +174,7 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
174174
}
175175
builder.timeField(CREATE_TIME.getPreferredName(), CREATE_TIME.getPreferredName() + "_string", createTime.toEpochMilli());
176176
if (definition != null) {
177-
builder.field(DEFINITION.getPreferredName(), definition);
177+
builder.field(DEFINITION.getPreferredName(), definition, params);
178178
}
179179
builder.field(TAGS.getPreferredName(), tags);
180180
if (metadata != null) {

x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/TrainedModelDefinition.java

Lines changed: 46 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,11 +5,15 @@
55
*/
66
package org.elasticsearch.xpack.core.ml.inference;
77

8+
import org.apache.lucene.util.Accountable;
9+
import org.apache.lucene.util.Accountables;
10+
import org.apache.lucene.util.RamUsageEstimator;
811
import org.elasticsearch.common.ParseField;
912
import org.elasticsearch.common.Strings;
1013
import org.elasticsearch.common.io.stream.StreamInput;
1114
import org.elasticsearch.common.io.stream.StreamOutput;
1215
import org.elasticsearch.common.io.stream.Writeable;
16+
import org.elasticsearch.common.unit.ByteSizeValue;
1317
import org.elasticsearch.common.xcontent.ConstructingObjectParser;
1418
import org.elasticsearch.common.xcontent.ObjectParser;
1519
import org.elasticsearch.common.xcontent.ToXContentObject;
@@ -25,16 +29,21 @@
2529
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TrainedModel;
2630
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
2731
import org.elasticsearch.xpack.core.ml.utils.NamedXContentObjectHelper;
32+
import org.elasticsearch.xpack.core.ml.utils.ToXContentParams;
2833

2934
import java.io.IOException;
35+
import java.util.ArrayList;
36+
import java.util.Collection;
3037
import java.util.Collections;
3138
import java.util.List;
3239
import java.util.Map;
3340
import java.util.Objects;
3441

35-
public class TrainedModelDefinition implements ToXContentObject, Writeable {
42+
public class TrainedModelDefinition implements ToXContentObject, Writeable, Accountable {
3643

44+
private static final long SHALLOW_SIZE = RamUsageEstimator.shallowSizeOfInstance(TrainedModelDefinition.class);
3745
public static final String NAME = "trained_mode_definition";
46+
public static final String HEAP_MEMORY_ESTIMATION = "heap_memory_estimation";
3847

3948
public static final ParseField TRAINED_MODEL = new ParseField("trained_model");
4049
public static final ParseField PREPROCESSORS = new ParseField("preprocessors");
@@ -105,6 +114,11 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
105114
PREPROCESSORS.getPreferredName(),
106115
preProcessors);
107116
builder.field(INPUT.getPreferredName(), input);
117+
if (params.paramAsBoolean(ToXContentParams.FOR_INTERNAL_STORAGE, false) == false) {
118+
builder.humanReadableField(HEAP_MEMORY_ESTIMATION + "_bytes",
119+
HEAP_MEMORY_ESTIMATION,
120+
new ByteSizeValue(ramBytesUsed()));
121+
}
108122
builder.endObject();
109123
return builder;
110124
}
@@ -150,6 +164,26 @@ public int hashCode() {
150164
return Objects.hash(trainedModel, input, preProcessors);
151165
}
152166

167+
@Override
168+
public long ramBytesUsed() {
169+
long size = SHALLOW_SIZE;
170+
size += RamUsageEstimator.sizeOf(trainedModel);
171+
size += RamUsageEstimator.sizeOf(input);
172+
size += RamUsageEstimator.sizeOfCollection(preProcessors);
173+
return size;
174+
}
175+
176+
@Override
177+
public Collection<Accountable> getChildResources() {
178+
List<Accountable> accountables = new ArrayList<>(preProcessors.size() + 2);
179+
accountables.add(Accountables.namedAccountable("input", input));
180+
accountables.add(Accountables.namedAccountable("trained_model", trainedModel));
181+
for(PreProcessor preProcessor : preProcessors) {
182+
accountables.add(Accountables.namedAccountable("pre_processor_" + preProcessor.getName(), preProcessor));
183+
}
184+
return accountables;
185+
}
186+
153187
public static class Builder {
154188

155189
private List<PreProcessor> preProcessors;
@@ -204,8 +238,9 @@ public TrainedModelDefinition build() {
204238
}
205239
}
206240

207-
public static class Input implements ToXContentObject, Writeable {
241+
public static class Input implements ToXContentObject, Writeable, Accountable {
208242

243+
private static final long SHALLOW_SIZE = RamUsageEstimator.shallowSizeOfInstance(Input.class);
209244
public static final String NAME = "trained_mode_definition_input";
210245
public static final ParseField FIELD_NAMES = new ParseField("field_names");
211246

@@ -265,6 +300,15 @@ public int hashCode() {
265300
return Objects.hash(fieldNames);
266301
}
267302

303+
@Override
304+
public long ramBytesUsed() {
305+
return SHALLOW_SIZE + RamUsageEstimator.sizeOfCollection(fieldNames);
306+
}
307+
308+
@Override
309+
public String toString() {
310+
return Strings.toString(this);
311+
}
268312
}
269313

270314
}

x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/preprocessing/FrequencyEncoding.java

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,9 @@
55
*/
66
package org.elasticsearch.xpack.core.ml.inference.preprocessing;
77

8+
import org.apache.lucene.util.RamUsageEstimator;
89
import org.elasticsearch.common.ParseField;
10+
import org.elasticsearch.common.Strings;
911
import org.elasticsearch.common.io.stream.StreamInput;
1012
import org.elasticsearch.common.io.stream.StreamOutput;
1113
import org.elasticsearch.common.xcontent.ConstructingObjectParser;
@@ -25,6 +27,8 @@
2527
*/
2628
public class FrequencyEncoding implements LenientlyParsedPreProcessor, StrictlyParsedPreProcessor {
2729

30+
private static final long SHALLOW_SIZE = RamUsageEstimator.shallowSizeOfInstance(FrequencyEncoding.class);
31+
2832
public static final ParseField NAME = new ParseField("frequency_encoding");
2933
public static final ParseField FIELD = new ParseField("field");
3034
public static final ParseField FEATURE_NAME = new ParseField("feature_name");
@@ -143,4 +147,17 @@ public int hashCode() {
143147
return Objects.hash(field, featureName, frequencyMap);
144148
}
145149

150+
@Override
151+
public long ramBytesUsed() {
152+
long size = SHALLOW_SIZE;
153+
size += RamUsageEstimator.sizeOf(field);
154+
size += RamUsageEstimator.sizeOf(featureName);
155+
size += RamUsageEstimator.sizeOfMap(frequencyMap);
156+
return size;
157+
}
158+
159+
@Override
160+
public String toString() {
161+
return Strings.toString(this);
162+
}
146163
}

x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/preprocessing/OneHotEncoding.java

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,9 @@
55
*/
66
package org.elasticsearch.xpack.core.ml.inference.preprocessing;
77

8+
import org.apache.lucene.util.RamUsageEstimator;
89
import org.elasticsearch.common.ParseField;
10+
import org.elasticsearch.common.Strings;
911
import org.elasticsearch.common.io.stream.StreamInput;
1012
import org.elasticsearch.common.io.stream.StreamOutput;
1113
import org.elasticsearch.common.xcontent.ConstructingObjectParser;
@@ -23,6 +25,7 @@
2325
*/
2426
public class OneHotEncoding implements LenientlyParsedPreProcessor, StrictlyParsedPreProcessor {
2527

28+
private static final long SHALLOW_SIZE = RamUsageEstimator.shallowSizeOfInstance(OneHotEncoding.class);
2629
public static final ParseField NAME = new ParseField("one_hot_encoding");
2730
public static final ParseField FIELD = new ParseField("field");
2831
public static final ParseField HOT_MAP = new ParseField("hot_map");
@@ -127,4 +130,16 @@ public int hashCode() {
127130
return Objects.hash(field, hotMap);
128131
}
129132

133+
@Override
134+
public long ramBytesUsed() {
135+
long size = SHALLOW_SIZE;
136+
size += RamUsageEstimator.sizeOf(field);
137+
size += RamUsageEstimator.sizeOfMap(hotMap);
138+
return size;
139+
}
140+
141+
@Override
142+
public String toString() {
143+
return Strings.toString(this);
144+
}
130145
}

x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/preprocessing/PreProcessor.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
*/
66
package org.elasticsearch.xpack.core.ml.inference.preprocessing;
77

8+
import org.apache.lucene.util.Accountable;
89
import org.elasticsearch.common.io.stream.NamedWriteable;
910
import org.elasticsearch.xpack.core.ml.utils.NamedXContentObject;
1011

@@ -14,7 +15,7 @@
1415
* Describes a pre-processor for a defined machine learning model
1516
* This processor should take a set of fields and return the modified set of fields.
1617
*/
17-
public interface PreProcessor extends NamedXContentObject, NamedWriteable {
18+
public interface PreProcessor extends NamedXContentObject, NamedWriteable, Accountable {
1819

1920
/**
2021
* Process the given fields and their values and return the modified map.

x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/preprocessing/TargetMeanEncoding.java

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,9 @@
55
*/
66
package org.elasticsearch.xpack.core.ml.inference.preprocessing;
77

8+
import org.apache.lucene.util.RamUsageEstimator;
89
import org.elasticsearch.common.ParseField;
10+
import org.elasticsearch.common.Strings;
911
import org.elasticsearch.common.io.stream.StreamInput;
1012
import org.elasticsearch.common.io.stream.StreamOutput;
1113
import org.elasticsearch.common.xcontent.ConstructingObjectParser;
@@ -25,6 +27,7 @@
2527
*/
2628
public class TargetMeanEncoding implements LenientlyParsedPreProcessor, StrictlyParsedPreProcessor {
2729

30+
private static final long SHALLOW_SIZE = RamUsageEstimator.shallowSizeOfInstance(TargetMeanEncoding.class);
2831
public static final ParseField NAME = new ParseField("target_mean_encoding");
2932
public static final ParseField FIELD = new ParseField("field");
3033
public static final ParseField FEATURE_NAME = new ParseField("feature_name");
@@ -158,4 +161,17 @@ public int hashCode() {
158161
return Objects.hash(field, featureName, meanMap, defaultValue);
159162
}
160163

164+
@Override
165+
public long ramBytesUsed() {
166+
long size = SHALLOW_SIZE;
167+
size += RamUsageEstimator.sizeOf(field);
168+
size += RamUsageEstimator.sizeOf(featureName);
169+
size += RamUsageEstimator.sizeOfMap(meanMap);
170+
return size;
171+
}
172+
173+
@Override
174+
public String toString() {
175+
return Strings.toString(this);
176+
}
161177
}

x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/TrainedModel.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
*/
66
package org.elasticsearch.xpack.core.ml.inference.trainedmodel;
77

8+
import org.apache.lucene.util.Accountable;
89
import org.elasticsearch.common.Nullable;
910
import org.elasticsearch.common.io.stream.NamedWriteable;
1011
import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults;
@@ -13,7 +14,7 @@
1314
import java.util.List;
1415
import java.util.Map;
1516

16-
public interface TrainedModel extends NamedXContentObject, NamedWriteable {
17+
public interface TrainedModel extends NamedXContentObject, NamedWriteable, Accountable {
1718

1819
/**
1920
* @return List of featureNames expected by the model. In the order that they are expected

x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/Ensemble.java

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,9 @@
55
*/
66
package org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble;
77

8+
import org.apache.lucene.util.Accountable;
9+
import org.apache.lucene.util.Accountables;
10+
import org.apache.lucene.util.RamUsageEstimator;
811
import org.elasticsearch.common.Nullable;
912
import org.elasticsearch.common.ParseField;
1013
import org.elasticsearch.common.io.stream.StreamInput;
@@ -29,6 +32,8 @@
2932
import org.elasticsearch.xpack.core.ml.utils.NamedXContentObjectHelper;
3033

3134
import java.io.IOException;
35+
import java.util.ArrayList;
36+
import java.util.Collection;
3237
import java.util.Collections;
3338
import java.util.List;
3439
import java.util.Map;
@@ -39,6 +44,7 @@
3944

4045
public class Ensemble implements LenientlyParsedTrainedModel, StrictlyParsedTrainedModel {
4146

47+
private static final long SHALLOW_SIZE = RamUsageEstimator.shallowSizeOfInstance(Ensemble.class);
4248
// TODO should we have regression/classification sub-classes that accept the builder?
4349
public static final ParseField NAME = new ParseField("ensemble");
4450
public static final ParseField FEATURE_NAMES = new ParseField("feature_names");
@@ -249,6 +255,26 @@ public static Builder builder() {
249255
return new Builder();
250256
}
251257

258+
@Override
259+
public long ramBytesUsed() {
260+
long size = SHALLOW_SIZE;
261+
size += RamUsageEstimator.sizeOfCollection(featureNames);
262+
size += RamUsageEstimator.sizeOfCollection(classificationLabels);
263+
size += RamUsageEstimator.sizeOfCollection(models);
264+
size += outputAggregator.ramBytesUsed();
265+
return size;
266+
}
267+
268+
@Override
269+
public Collection<Accountable> getChildResources() {
270+
List<Accountable> accountables = new ArrayList<>(models.size() + 1);
271+
for (TrainedModel model : models) {
272+
accountables.add(Accountables.namedAccountable(model.getName(), model));
273+
}
274+
accountables.add(Accountables.namedAccountable(outputAggregator.getName(), outputAggregator));
275+
return Collections.unmodifiableCollection(accountables);
276+
}
277+
252278
public static class Builder {
253279
private List<String> featureNames;
254280
private List<TrainedModel> trainedModels;

0 commit comments

Comments
 (0)