[ML] Accommodate changes to calculation of "model bytes" and "peak model bytes" values. #126256

Open
wants to merge 15 commits into
base: main
Changes from 6 commits
5 changes: 5 additions & 0 deletions docs/changelog/126256.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
pr: 126256
summary: Handle new `actual_memory_usage_bytes` field in model size stats
area: Machine Learning
type: enhancement
issues: []
Original file line number Diff line number Diff line change
@@ -214,6 +214,7 @@ static TransportVersion def(int id) {
public static final TransportVersion ESQL_REMOVE_AGGREGATE_TYPE = def(9_045_0_00);
public static final TransportVersion ADD_PROJECT_ID_TO_DSL_ERROR_INFO = def(9_046_0_00);
public static final TransportVersion SEMANTIC_TEXT_CHUNKING_CONFIG = def(9_047_00_0);
public static final TransportVersion ML_AD_ACTUAL_MEMORY_USAGE = def(9_048_0_00);

/*
* STOP! READ THIS FIRST! No, really,
Original file line number Diff line number Diff line change
@@ -153,12 +153,13 @@ private static void checkUniqueness(int id, String uniqueId) {
// V_11 is used in ELSER v2 package configs
public static final MlConfigVersion V_11 = registerMlConfigVersion(11_00_0_0_99, "79CB2950-57C7-11EE-AE5D-0800200C9A66");
public static final MlConfigVersion V_12 = registerMlConfigVersion(12_00_0_0_99, "Trained model config prefix strings added");
public static final MlConfigVersion V_13 = registerMlConfigVersion(13_00_0_0_99, "Anomaly Detection reports actual memory usage");

/**
* Reference to the most recent Ml config version.
* This should be the Ml config version with the highest id.
*/
- public static final MlConfigVersion CURRENT = V_12;
+ public static final MlConfigVersion CURRENT = V_13;
Member

Is the config version change necessary? Model size stats are considered a result type, not a config type, and are stored in the .ml-anomalies-x indices.

ModelSizeStats has two parsers: a strict parser that errors if it does not recognise a field, and a lenient parser that ignores unknown fields. The strict parser is used when reading the output from autodetect, and the lenient parser is used when reading documents from Elasticsearch. This way, in a mixed-version cluster, if an upgraded node stores the stats doc with the actual_memory_usage_bytes field, an old node that reads the doc won't error when it sees the unknown field. This mechanism works well for backwards compatibility and means that, most of the time, versioning is not required.
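
To make the strict/lenient distinction concrete, here is a minimal sketch in plain Java. It does not use Elasticsearch's parser classes; `StatsParserSketch`, `KNOWN_FIELDS`, and the `lenient` flag are hypothetical names chosen for illustration, and the set of known fields models what an old node would recognise.

```java
import java.util.HashMap;
import java.util.Map;
import java.util.Set;

// Hypothetical stand-in for the two ModelSizeStats parsers; not the Elasticsearch API.
final class StatsParserSketch {
    // Fields an "old" node knows about (assumed subset, for illustration only).
    private static final Set<String> KNOWN_FIELDS = Set.of("job_id", "model_bytes", "peak_model_bytes");

    /**
     * Copies the known fields out of an already-decoded document.
     * When lenient is false (strict), an unrecognised field is an error;
     * when lenient is true, unrecognised fields are silently skipped.
     */
    static Map<String, Object> parse(Map<String, Object> doc, boolean lenient) {
        Map<String, Object> parsed = new HashMap<>();
        for (Map.Entry<String, Object> field : doc.entrySet()) {
            if (KNOWN_FIELDS.contains(field.getKey())) {
                parsed.put(field.getKey(), field.getValue());
            } else if (lenient == false) {
                throw new IllegalArgumentException("unknown field [" + field.getKey() + "]");
            }
        }
        return parsed;
    }
}
```

With a document that carries `actual_memory_usage_bytes`, the lenient call returns only the fields the old node knows, while the strict call throws. That is why a stored document written by an upgraded node stays readable on an old node without a version bump.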

Contributor Author

TIL, Thanks Dave!


/**
* Reference to the first MlConfigVersion that is detached from the
Original file line number Diff line number Diff line change
@@ -41,6 +41,7 @@ public class ModelSizeStats implements ToXContentObject, Writeable {
*/
public static final ParseField MODEL_BYTES_FIELD = new ParseField("model_bytes");
public static final ParseField PEAK_MODEL_BYTES_FIELD = new ParseField("peak_model_bytes");
public static final ParseField ACTUAL_MEMORY_USAGE_BYTES = new ParseField("actual_memory_usage_bytes");
public static final ParseField MODEL_BYTES_EXCEEDED_FIELD = new ParseField("model_bytes_exceeded");
public static final ParseField MODEL_BYTES_MEMORY_LIMIT_FIELD = new ParseField("model_bytes_memory_limit");
public static final ParseField TOTAL_BY_FIELD_COUNT_FIELD = new ParseField("total_by_field_count");
@@ -74,6 +75,7 @@ private static ConstructingObjectParser<Builder, Void> createParser(boolean igno
parser.declareString((modelSizeStat, s) -> {}, Result.RESULT_TYPE);
parser.declareLong(Builder::setModelBytes, MODEL_BYTES_FIELD);
parser.declareLong(Builder::setPeakModelBytes, PEAK_MODEL_BYTES_FIELD);
parser.declareLong(Builder::setActualMemoryUsageBytes, ACTUAL_MEMORY_USAGE_BYTES);
parser.declareLong(Builder::setModelBytesExceeded, MODEL_BYTES_EXCEEDED_FIELD);
parser.declareLong(Builder::setModelBytesMemoryLimit, MODEL_BYTES_MEMORY_LIMIT_FIELD);
parser.declareLong(Builder::setBucketAllocationFailuresCount, BUCKET_ALLOCATION_FAILURES_COUNT_FIELD);
@@ -152,14 +154,16 @@ public String toString() {
* 1. The job's model_memory_limit
* 2. The current model memory, i.e. what's reported in model_bytes of this object
* 3. The peak model memory, i.e. what's reported in peak_model_bytes of this object
* 4. The actual memory usage, i.e. what's reported in actual_memory_usage_bytes of this object
* The field storing this enum can also be <code>null</code>, which means the
* assignment code will decide on the fly - this was the old behaviour prior
* to 7.11.
*/
public enum AssignmentMemoryBasis implements Writeable {
MODEL_MEMORY_LIMIT,
CURRENT_MODEL_BYTES,
- PEAK_MODEL_BYTES;
+ PEAK_MODEL_BYTES,
+ ACTUAL_MEMORY_USAGE_BYTES;

public static AssignmentMemoryBasis fromString(String statusName) {
return valueOf(statusName.trim().toUpperCase(Locale.ROOT));
@@ -183,6 +187,7 @@ public String toString() {
private final String jobId;
private final long modelBytes;
private final Long peakModelBytes;
private final Long actualMemoryUsageBytes;
private final Long modelBytesExceeded;
private final Long modelBytesMemoryLimit;
private final long totalByFieldCount;
@@ -206,6 +211,7 @@ private ModelSizeStats(
String jobId,
long modelBytes,
Long peakModelBytes,
Long actualMemoryUsageBytes,
Long modelBytesExceeded,
Long modelBytesMemoryLimit,
long totalByFieldCount,
@@ -228,6 +234,7 @@ private ModelSizeStats(
this.jobId = jobId;
this.modelBytes = modelBytes;
this.peakModelBytes = peakModelBytes;
this.actualMemoryUsageBytes = actualMemoryUsageBytes;
this.modelBytesExceeded = modelBytesExceeded;
this.modelBytesMemoryLimit = modelBytesMemoryLimit;
this.totalByFieldCount = totalByFieldCount;
@@ -252,6 +259,11 @@ public ModelSizeStats(StreamInput in) throws IOException {
jobId = in.readString();
modelBytes = in.readVLong();
peakModelBytes = in.readOptionalLong();
if (in.getTransportVersion().onOrAfter(TransportVersions.ML_AD_ACTUAL_MEMORY_USAGE)) {
actualMemoryUsageBytes = in.readOptionalLong();
} else {
actualMemoryUsageBytes = null;
}
modelBytesExceeded = in.readOptionalLong();
modelBytesMemoryLimit = in.readOptionalLong();
totalByFieldCount = in.readVLong();
@@ -293,6 +305,9 @@ public void writeTo(StreamOutput out) throws IOException {
out.writeString(jobId);
out.writeVLong(modelBytes);
out.writeOptionalLong(peakModelBytes);
if (out.getTransportVersion().onOrAfter(TransportVersions.ML_AD_ACTUAL_MEMORY_USAGE)) {
out.writeOptionalLong(actualMemoryUsageBytes);
}
out.writeOptionalLong(modelBytesExceeded);
out.writeOptionalLong(modelBytesMemoryLimit);
out.writeVLong(totalByFieldCount);
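
The read and write hunks above gate the new field on the transport version so that both ends of a connection agree on the byte layout. A minimal sketch of that pattern follows, using `java.io` data streams rather than Elasticsearch's `StreamInput`/`StreamOutput`. The class name, the integer version, and `FIELD_ADDED_IN` are hypothetical stand-ins for `TransportVersions.ML_AD_ACTUAL_MEMORY_USAGE`, and only two fields are modelled.

```java
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.io.IOException;
import java.io.UncheckedIOException;

// Hypothetical sketch of version-gated wire serialization; not the Elasticsearch stream API.
final class VersionGatedWireSketch {
    // Assumed stand-in for the transport version that introduced the field.
    static final int FIELD_ADDED_IN = 2;

    static byte[] write(long modelBytes, Long actualMemoryUsageBytes, int peerVersion) {
        ByteArrayOutputStream bytes = new ByteArrayOutputStream();
        try (DataOutputStream out = new DataOutputStream(bytes)) {
            out.writeLong(modelBytes);
            if (peerVersion >= FIELD_ADDED_IN) {
                // Optional value: a presence flag first, then the value if present.
                out.writeBoolean(actualMemoryUsageBytes != null);
                if (actualMemoryUsageBytes != null) {
                    out.writeLong(actualMemoryUsageBytes);
                }
            }
        } catch (IOException e) {
            throw new UncheckedIOException(e);
        }
        return bytes.toByteArray();
    }

    /** Returns {modelBytes, actualMemoryUsageBytes}; the second element is null when absent. */
    static Long[] read(byte[] payload, int peerVersion) {
        try (DataInputStream in = new DataInputStream(new ByteArrayInputStream(payload))) {
            long modelBytes = in.readLong();
            Long actual = null;
            // Only consume the optional field if the peer's version wrote it.
            if (peerVersion >= FIELD_ADDED_IN && in.readBoolean()) {
                actual = in.readLong();
            }
            return new Long[] { modelBytes, actual };
        } catch (IOException e) {
            throw new UncheckedIOException(e);
        }
    }
}
```

The writer and the reader must apply the same version check. If only one side were gated, every field after the optional value would be read from the wrong offset.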
@@ -339,6 +354,9 @@ public XContentBuilder doXContentBody(XContentBuilder builder) throws IOExceptio
if (peakModelBytes != null) {
builder.field(PEAK_MODEL_BYTES_FIELD.getPreferredName(), peakModelBytes);
}
if (actualMemoryUsageBytes != null) {
builder.field(ACTUAL_MEMORY_USAGE_BYTES.getPreferredName(), actualMemoryUsageBytes);
}
if (modelBytesExceeded != null) {
builder.field(MODEL_BYTES_EXCEEDED_FIELD.getPreferredName(), modelBytesExceeded);
}
@@ -391,6 +409,10 @@ public Long getPeakModelBytes() {
return peakModelBytes;
}

public Long getActualMemoryUsageBytes() {
return actualMemoryUsageBytes;
}

public Long getModelBytesExceeded() {
return modelBytesExceeded;
}
@@ -479,6 +501,7 @@ public int hashCode() {
jobId,
modelBytes,
peakModelBytes,
actualMemoryUsageBytes,
modelBytesExceeded,
modelBytesMemoryLimit,
totalByFieldCount,
@@ -517,6 +540,7 @@ public boolean equals(Object other) {

return this.modelBytes == that.modelBytes
&& Objects.equals(this.peakModelBytes, that.peakModelBytes)
&& Objects.equals(this.actualMemoryUsageBytes, that.actualMemoryUsageBytes)
&& Objects.equals(this.modelBytesExceeded, that.modelBytesExceeded)
&& Objects.equals(this.modelBytesMemoryLimit, that.modelBytesMemoryLimit)
&& this.totalByFieldCount == that.totalByFieldCount
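
Because `actualMemoryUsageBytes` is declared as a boxed, nullable `Long`, the way it is compared in `equals` matters. The sketch below contrasts the two options; the class and method names are hypothetical and exist only for illustration.

```java
import java.util.Objects;

// Hypothetical helper contrasting identity and value comparison of boxed Longs.
final class BoxedLongEqualitySketch {
    // '==' on two Long references compares object identity, not the numeric value.
    // For equal values outside the small cached range this may return false.
    static boolean sameByReference(Long a, Long b) {
        return a == b;
    }

    // Objects.equals is null-safe and compares by value via Long.equals.
    static boolean sameByValue(Long a, Long b) {
        return Objects.equals(a, b);
    }
}
```

This matches how the neighbouring nullable fields (`peakModelBytes`, `modelBytesExceeded`, `modelBytesMemoryLimit`) are compared, while the primitive `long` fields can safely use `==`.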
@@ -543,6 +567,7 @@ public static class Builder {
private final String jobId;
private long modelBytes;
private Long peakModelBytes;
private Long actualMemoryUsageBytes;
private Long modelBytesExceeded;
private Long modelBytesMemoryLimit;
private long totalByFieldCount;
@@ -573,6 +598,7 @@ public Builder(ModelSizeStats modelSizeStats) {
this.jobId = modelSizeStats.jobId;
this.modelBytes = modelSizeStats.modelBytes;
this.peakModelBytes = modelSizeStats.peakModelBytes;
this.actualMemoryUsageBytes = modelSizeStats.actualMemoryUsageBytes;
this.modelBytesExceeded = modelSizeStats.modelBytesExceeded;
this.modelBytesMemoryLimit = modelSizeStats.modelBytesMemoryLimit;
this.totalByFieldCount = modelSizeStats.totalByFieldCount;
@@ -603,6 +629,11 @@ public Builder setPeakModelBytes(long peakModelBytes) {
return this;
}

public Builder setActualMemoryUsageBytes(long actualMemoryUsageBytes) {
this.actualMemoryUsageBytes = actualMemoryUsageBytes;
return this;
}

public Builder setModelBytesExceeded(long modelBytesExceeded) {
this.modelBytesExceeded = modelBytesExceeded;
return this;
@@ -700,6 +731,7 @@ public ModelSizeStats build() {
jobId,
modelBytes,
peakModelBytes,
actualMemoryUsageBytes,
modelBytesExceeded,
modelBytesMemoryLimit,
totalByFieldCount,
Original file line number Diff line number Diff line change
@@ -1590,6 +1590,13 @@ void calculateEstablishedMemoryUsage(
handler.accept((storedPeak != null) ? storedPeak : latestModelSizeStats.getModelBytes());
return;
}
case ACTUAL_MEMORY_USAGE_BYTES -> {
Long storedActualMemoryUsageBytes = latestModelSizeStats.getActualMemoryUsageBytes();
handler.accept(
(storedActualMemoryUsageBytes != null) ? storedActualMemoryUsageBytes : latestModelSizeStats.getModelBytes()
);
return;
}
}
}

Original file line number Diff line number Diff line change
@@ -1076,6 +1076,8 @@ public ByteSizeValue getOpenProcessMemoryUsage() {
case MODEL_MEMORY_LIMIT -> Optional.ofNullable(modelSizeStats.getModelBytesMemoryLimit()).orElse(0L);
case CURRENT_MODEL_BYTES -> modelSizeStats.getModelBytes();
case PEAK_MODEL_BYTES -> Optional.ofNullable(modelSizeStats.getPeakModelBytes()).orElse(modelSizeStats.getModelBytes());
case ACTUAL_MEMORY_USAGE_BYTES -> Optional.ofNullable(modelSizeStats.getActualMemoryUsageBytes())
.orElse(modelSizeStats.getModelBytes());
};
memoryUsedBytes += Job.PROCESS_MEMORY_OVERHEAD.getBytes();
}
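
The switch above picks a memory figure according to the assignment memory basis and falls back to `model_bytes` when the preferred value was never reported, for example when the stats come from a job that predates `actual_memory_usage_bytes`. A self-contained sketch of that selection logic is shown here. `MemoryBasisSketch`, `Basis`, and the `Stats` record are hypothetical stand-ins for the real classes, and the per-process overhead that the original adds afterwards is left out.

```java
import java.util.Optional;

// Hypothetical, simplified model of the basis-driven memory selection with fallback.
final class MemoryBasisSketch {
    enum Basis { MODEL_MEMORY_LIMIT, CURRENT_MODEL_BYTES, PEAK_MODEL_BYTES, ACTUAL_MEMORY_USAGE_BYTES }

    // Nullable Long components model values that may be missing from the stats.
    record Stats(long modelBytes, Long peakModelBytes, Long actualMemoryUsageBytes, Long modelBytesMemoryLimit) {}

    static long memoryFor(Basis basis, Stats stats) {
        return switch (basis) {
            case MODEL_MEMORY_LIMIT -> Optional.ofNullable(stats.modelBytesMemoryLimit()).orElse(0L);
            case CURRENT_MODEL_BYTES -> stats.modelBytes();
            // Both of these fall back to the always-present model_bytes value.
            case PEAK_MODEL_BYTES -> Optional.ofNullable(stats.peakModelBytes()).orElse(stats.modelBytes());
            case ACTUAL_MEMORY_USAGE_BYTES -> Optional.ofNullable(stats.actualMemoryUsageBytes())
                .orElse(stats.modelBytes());
        };
    }
}
```

Because the switch is an expression over every enum constant, the compiler rejects it if a new basis is added without a matching case, which is a useful property when the enum grows as it does in this change.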
Original file line number Diff line number Diff line change
@@ -834,10 +834,12 @@ public void testGetOpenProcessMemoryUsage() {
long modelMemoryLimitBytes = ByteSizeValue.ofMb(randomIntBetween(10, 1000)).getBytes();
long peakModelBytes = randomLongBetween(100000, modelMemoryLimitBytes - 1);
long modelBytes = randomLongBetween(1, peakModelBytes - 1);
long actualMemoryUsageBytes = randomLongBetween(Math.min(262144, peakModelBytes - 1), peakModelBytes - 1);
AssignmentMemoryBasis assignmentMemoryBasis = randomFrom(AssignmentMemoryBasis.values());
modelSizeStats = new ModelSizeStats.Builder("foo").setModelBytesMemoryLimit(modelMemoryLimitBytes)
.setPeakModelBytes(peakModelBytes)
.setModelBytes(modelBytes)
.setActualMemoryUsageBytes(actualMemoryUsageBytes)
.setAssignmentMemoryBasis(assignmentMemoryBasis)
.build();
when(autodetectCommunicator.getModelSizeStats()).thenReturn(modelSizeStats);
@@ -850,6 +852,7 @@ public void testGetOpenProcessMemoryUsage() {
case MODEL_MEMORY_LIMIT -> modelMemoryLimitBytes;
case CURRENT_MODEL_BYTES -> modelBytes;
case PEAK_MODEL_BYTES -> peakModelBytes;
case ACTUAL_MEMORY_USAGE_BYTES -> actualMemoryUsageBytes;
};
assertThat(manager.getOpenProcessMemoryUsage(), equalTo(ByteSizeValue.ofBytes(expectedSizeBytes)));
}