Skip to content

[ML] Handle new actual_memory_usage_bytes field in model size stats. #126256

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 9 commits into
base: main
Choose a base branch
from
Open
5 changes: 5 additions & 0 deletions docs/changelog/126256.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
pr: 126256
summary: Handle new `actual_memory_usage_bytes` field in model size stats
area: Machine Learning
type: enhancement
issues: []
Original file line number Diff line number Diff line change
Expand Up @@ -214,6 +214,7 @@ static TransportVersion def(int id) {
public static final TransportVersion ESQL_REMOVE_AGGREGATE_TYPE = def(9_045_0_00);
public static final TransportVersion ADD_PROJECT_ID_TO_DSL_ERROR_INFO = def(9_046_0_00);
public static final TransportVersion SEMANTIC_TEXT_CHUNKING_CONFIG = def(9_047_00_0);
public static final TransportVersion ML_AD_SYSTEM_MEMORY_USAGE = def(9_048_0_00);

/*
* STOP! READ THIS FIRST! No, really,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,8 @@ public class ModelSizeStats implements ToXContentObject, Writeable {
*/
public static final ParseField MODEL_BYTES_FIELD = new ParseField("model_bytes");
public static final ParseField PEAK_MODEL_BYTES_FIELD = new ParseField("peak_model_bytes");
public static final ParseField SYSTEM_MEMORY_BYTES = new ParseField("system_memory_bytes");
public static final ParseField MAX_SYSTEM_MEMORY_BYTES = new ParseField("max_system_memory_bytes");
public static final ParseField MODEL_BYTES_EXCEEDED_FIELD = new ParseField("model_bytes_exceeded");
public static final ParseField MODEL_BYTES_MEMORY_LIMIT_FIELD = new ParseField("model_bytes_memory_limit");
public static final ParseField TOTAL_BY_FIELD_COUNT_FIELD = new ParseField("total_by_field_count");
Expand Down Expand Up @@ -74,6 +76,8 @@ private static ConstructingObjectParser<Builder, Void> createParser(boolean igno
parser.declareString((modelSizeStat, s) -> {}, Result.RESULT_TYPE);
parser.declareLong(Builder::setModelBytes, MODEL_BYTES_FIELD);
parser.declareLong(Builder::setPeakModelBytes, PEAK_MODEL_BYTES_FIELD);
parser.declareLong(Builder::setSystemMemoryBytes, SYSTEM_MEMORY_BYTES);
parser.declareLong(Builder::setMaxSystemMemoryBytes, MAX_SYSTEM_MEMORY_BYTES);
parser.declareLong(Builder::setModelBytesExceeded, MODEL_BYTES_EXCEEDED_FIELD);
parser.declareLong(Builder::setModelBytesMemoryLimit, MODEL_BYTES_MEMORY_LIMIT_FIELD);
parser.declareLong(Builder::setBucketAllocationFailuresCount, BUCKET_ALLOCATION_FAILURES_COUNT_FIELD);
Expand Down Expand Up @@ -152,14 +156,18 @@ public String toString() {
* 1. The job's model_memory_limit
* 2. The current model memory, i.e. what's reported in model_bytes of this object
* 3. The peak model memory, i.e. what's reported in peak_model_bytes of this object
* 4. The system memory, i.e. what's reported in system_memory_bytes of this object
* 5. The max system memory, i.e. what's reported in max_system_memory_bytes of this object
* The field storing this enum can also be <code>null</code>, which means the
* assignment code will decide on the fly - this was the old behaviour prior
* to 7.11.
*/
public enum AssignmentMemoryBasis implements Writeable {
MODEL_MEMORY_LIMIT,
CURRENT_MODEL_BYTES,
PEAK_MODEL_BYTES;
PEAK_MODEL_BYTES,
SYSTEM_MEMORY_BYTES,
MAX_SYSTEM_MEMORY_BYTES,;

public static AssignmentMemoryBasis fromString(String statusName) {
return valueOf(statusName.trim().toUpperCase(Locale.ROOT));
Expand All @@ -183,6 +191,8 @@ public String toString() {
private final String jobId;
private final long modelBytes;
private final Long peakModelBytes;
private final Long systemMemoryUsageBytes;
private final Long maxSystemMemoryUsageBytes;
private final Long modelBytesExceeded;
private final Long modelBytesMemoryLimit;
private final long totalByFieldCount;
Expand All @@ -206,6 +216,8 @@ private ModelSizeStats(
String jobId,
long modelBytes,
Long peakModelBytes,
Long systemMemoryUsageBytes,
Long maxSystemMemoryUsageBytes,
Long modelBytesExceeded,
Long modelBytesMemoryLimit,
long totalByFieldCount,
Expand All @@ -228,6 +240,8 @@ private ModelSizeStats(
this.jobId = jobId;
this.modelBytes = modelBytes;
this.peakModelBytes = peakModelBytes;
this.systemMemoryUsageBytes = systemMemoryUsageBytes;
this.maxSystemMemoryUsageBytes = maxSystemMemoryUsageBytes;
this.modelBytesExceeded = modelBytesExceeded;
this.modelBytesMemoryLimit = modelBytesMemoryLimit;
this.totalByFieldCount = totalByFieldCount;
Expand All @@ -252,6 +266,13 @@ public ModelSizeStats(StreamInput in) throws IOException {
jobId = in.readString();
modelBytes = in.readVLong();
peakModelBytes = in.readOptionalLong();
if (in.getTransportVersion().onOrAfter(TransportVersions.ML_AD_SYSTEM_MEMORY_USAGE)) {
systemMemoryUsageBytes = in.readOptionalLong();
maxSystemMemoryUsageBytes = in.readOptionalLong();
} else {
systemMemoryUsageBytes = null;
maxSystemMemoryUsageBytes = null;
}
modelBytesExceeded = in.readOptionalLong();
modelBytesMemoryLimit = in.readOptionalLong();
totalByFieldCount = in.readVLong();
Expand Down Expand Up @@ -293,6 +314,10 @@ public void writeTo(StreamOutput out) throws IOException {
out.writeString(jobId);
out.writeVLong(modelBytes);
out.writeOptionalLong(peakModelBytes);
if (out.getTransportVersion().onOrAfter(TransportVersions.ML_AD_SYSTEM_MEMORY_USAGE)) {
out.writeOptionalLong(systemMemoryUsageBytes);
out.writeOptionalLong(maxSystemMemoryUsageBytes);
}
out.writeOptionalLong(modelBytesExceeded);
out.writeOptionalLong(modelBytesMemoryLimit);
out.writeVLong(totalByFieldCount);
Expand Down Expand Up @@ -339,6 +364,12 @@ public XContentBuilder doXContentBody(XContentBuilder builder) throws IOExceptio
if (peakModelBytes != null) {
builder.field(PEAK_MODEL_BYTES_FIELD.getPreferredName(), peakModelBytes);
}
if (systemMemoryUsageBytes != null) {
builder.field(SYSTEM_MEMORY_BYTES.getPreferredName(), systemMemoryUsageBytes);
}
if (maxSystemMemoryUsageBytes != null) {
builder.field(MAX_SYSTEM_MEMORY_BYTES.getPreferredName(), maxSystemMemoryUsageBytes);
}
if (modelBytesExceeded != null) {
builder.field(MODEL_BYTES_EXCEEDED_FIELD.getPreferredName(), modelBytesExceeded);
}
Expand Down Expand Up @@ -391,6 +422,14 @@ public Long getPeakModelBytes() {
return peakModelBytes;
}

/**
 * Returns the value reported under {@code system_memory_bytes}.
 *
 * @return the system memory usage in bytes, or {@code null} when no value
 *         is available (for example when this object was read from a stream
 *         older than {@code ML_AD_SYSTEM_MEMORY_USAGE})
 */
public Long getSystemMemoryBytes() {
    return this.systemMemoryUsageBytes;
}

/**
 * Returns the value reported under {@code max_system_memory_bytes}.
 *
 * @return the max system memory usage in bytes, or {@code null} when no
 *         value is available (for example when this object was read from a
 *         stream older than {@code ML_AD_SYSTEM_MEMORY_USAGE})
 */
public Long getMaxSystemMemoryBytes() {
    return this.maxSystemMemoryUsageBytes;
}

/**
 * Returns the value reported under {@code model_bytes_exceeded}.
 *
 * @return the stored value, or {@code null} when it is absent
 */
public Long getModelBytesExceeded() {
    return this.modelBytesExceeded;
}
Expand Down Expand Up @@ -479,6 +518,8 @@ public int hashCode() {
jobId,
modelBytes,
peakModelBytes,
systemMemoryUsageBytes,
maxSystemMemoryUsageBytes,
modelBytesExceeded,
modelBytesMemoryLimit,
totalByFieldCount,
Expand Down Expand Up @@ -517,6 +558,8 @@ public boolean equals(Object other) {

return this.modelBytes == that.modelBytes
&& Objects.equals(this.peakModelBytes, that.peakModelBytes)
// Both fields are boxed Long values that may be null: compare by value, not by reference identity.
&& Objects.equals(this.systemMemoryUsageBytes, that.systemMemoryUsageBytes)
&& Objects.equals(this.maxSystemMemoryUsageBytes, that.maxSystemMemoryUsageBytes)
&& Objects.equals(this.modelBytesExceeded, that.modelBytesExceeded)
&& Objects.equals(this.modelBytesMemoryLimit, that.modelBytesMemoryLimit)
&& this.totalByFieldCount == that.totalByFieldCount
Expand All @@ -543,6 +586,8 @@ public static class Builder {
private final String jobId;
private long modelBytes;
private Long peakModelBytes;
private Long systemMemoryUsageBytes;
private Long maxSystemMemoryUsageBytes;
private Long modelBytesExceeded;
private Long modelBytesMemoryLimit;
private long totalByFieldCount;
Expand Down Expand Up @@ -573,6 +618,8 @@ public Builder(ModelSizeStats modelSizeStats) {
this.jobId = modelSizeStats.jobId;
this.modelBytes = modelSizeStats.modelBytes;
this.peakModelBytes = modelSizeStats.peakModelBytes;
this.systemMemoryUsageBytes = modelSizeStats.systemMemoryUsageBytes;
this.maxSystemMemoryUsageBytes = modelSizeStats.maxSystemMemoryUsageBytes;
this.modelBytesExceeded = modelSizeStats.modelBytesExceeded;
this.modelBytesMemoryLimit = modelSizeStats.modelBytesMemoryLimit;
this.totalByFieldCount = modelSizeStats.totalByFieldCount;
Expand Down Expand Up @@ -603,6 +650,16 @@ public Builder setPeakModelBytes(long peakModelBytes) {
return this;
}

/**
 * Records the system memory usage, serialized as {@code system_memory_bytes}.
 *
 * @param systemMemoryUsageBytes the system memory usage in bytes
 * @return this builder, to allow call chaining
 */
public Builder setSystemMemoryBytes(long systemMemoryUsageBytes) {
    // Explicit boxing: the backing field is a nullable Long.
    this.systemMemoryUsageBytes = Long.valueOf(systemMemoryUsageBytes);
    return this;
}

/**
 * Records the max system memory usage, serialized as {@code max_system_memory_bytes}.
 *
 * @param maxSystemMemoryUsageBytes the max system memory usage in bytes
 * @return this builder, to allow call chaining
 */
public Builder setMaxSystemMemoryBytes(long maxSystemMemoryUsageBytes) {
    // Explicit boxing: the backing field is a nullable Long.
    this.maxSystemMemoryUsageBytes = Long.valueOf(maxSystemMemoryUsageBytes);
    return this;
}

public Builder setModelBytesExceeded(long modelBytesExceeded) {
this.modelBytesExceeded = modelBytesExceeded;
return this;
Expand Down Expand Up @@ -700,6 +757,8 @@ public ModelSizeStats build() {
jobId,
modelBytes,
peakModelBytes,
systemMemoryUsageBytes,
maxSystemMemoryUsageBytes,
modelBytesExceeded,
modelBytesMemoryLimit,
totalByFieldCount,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1590,6 +1590,18 @@ void calculateEstablishedMemoryUsage(
handler.accept((storedPeak != null) ? storedPeak : latestModelSizeStats.getModelBytes());
return;
}
case SYSTEM_MEMORY_BYTES -> {
Long storedSystemMemoryBytes = latestModelSizeStats.getSystemMemoryBytes();
handler.accept((storedSystemMemoryBytes != null) ? storedSystemMemoryBytes : latestModelSizeStats.getModelBytes());
return;
}
case MAX_SYSTEM_MEMORY_BYTES -> {
Long storedMaxSystemMemoryBytes = latestModelSizeStats.getMaxSystemMemoryBytes();
handler.accept(
(storedMaxSystemMemoryBytes != null) ? storedMaxSystemMemoryBytes : latestModelSizeStats.getModelBytes()
);
return;
}
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1076,6 +1076,10 @@ public ByteSizeValue getOpenProcessMemoryUsage() {
case MODEL_MEMORY_LIMIT -> Optional.ofNullable(modelSizeStats.getModelBytesMemoryLimit()).orElse(0L);
case CURRENT_MODEL_BYTES -> modelSizeStats.getModelBytes();
case PEAK_MODEL_BYTES -> Optional.ofNullable(modelSizeStats.getPeakModelBytes()).orElse(modelSizeStats.getModelBytes());
case SYSTEM_MEMORY_BYTES -> Optional.ofNullable(modelSizeStats.getSystemMemoryBytes())
.orElse(modelSizeStats.getModelBytes());
case MAX_SYSTEM_MEMORY_BYTES -> Optional.ofNullable(modelSizeStats.getMaxSystemMemoryBytes())
.orElse(modelSizeStats.getModelBytes());
};
memoryUsedBytes += Job.PROCESS_MEMORY_OVERHEAD.getBytes();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -834,10 +834,14 @@ public void testGetOpenProcessMemoryUsage() {
long modelMemoryLimitBytes = ByteSizeValue.ofMb(randomIntBetween(10, 1000)).getBytes();
long peakModelBytes = randomLongBetween(100000, modelMemoryLimitBytes - 1);
long modelBytes = randomLongBetween(1, peakModelBytes - 1);
// Bound by the memory limit (always >= 10MB) rather than peakModelBytes: the latter can be as low as
// 100000, which is below these lower bounds and would make the random range invalid (min > max).
long systemMemoryUsageBytes = randomLongBetween(262144, modelMemoryLimitBytes - 1);
long maxSystemMemoryUsageBytes = randomLongBetween(266240, modelMemoryLimitBytes - 1);
AssignmentMemoryBasis assignmentMemoryBasis = randomFrom(AssignmentMemoryBasis.values());
modelSizeStats = new ModelSizeStats.Builder("foo").setModelBytesMemoryLimit(modelMemoryLimitBytes)
.setPeakModelBytes(peakModelBytes)
.setModelBytes(modelBytes)
.setSystemMemoryBytes(systemMemoryUsageBytes)
.setMaxSystemMemoryBytes(maxSystemMemoryUsageBytes)
.setAssignmentMemoryBasis(assignmentMemoryBasis)
.build();
when(autodetectCommunicator.getModelSizeStats()).thenReturn(modelSizeStats);
Expand All @@ -850,6 +854,8 @@ public void testGetOpenProcessMemoryUsage() {
case MODEL_MEMORY_LIMIT -> modelMemoryLimitBytes;
case CURRENT_MODEL_BYTES -> modelBytes;
case PEAK_MODEL_BYTES -> peakModelBytes;
case SYSTEM_MEMORY_BYTES -> systemMemoryUsageBytes;
case MAX_SYSTEM_MEMORY_BYTES -> maxSystemMemoryUsageBytes;
};
assertThat(manager.getOpenProcessMemoryUsage(), equalTo(ByteSizeValue.ofBytes(expectedSizeBytes)));
}
Expand Down
Loading