Commit 85713f7

Semantic Text Chunking Indexing Pressure (#125517)
We have observed many OOMs due to the memory required to inject chunked inference results for semantic_text fields. This PR uses coordinating indexing pressure to account for this memory usage. When indexing pressure memory usage exceeds the threshold set by indexing_pressure.memory.limit, chunked inference result injection will be suspended to prevent OOMs.
1 parent 235867c commit 85713f7
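
Note that indexing_pressure.memory.limit is the existing node-level indexing pressure setting, not one introduced by this commit. For reference, a sketch of how it might be tuned in elasticsearch.yml (10% of the heap is, as far as we know, the default; the value below is illustrative only):

indexing_pressure.memory.limit: 10%

Once coordinating memory crosses this threshold, the increment calls introduced below throw EsRejectedExecutionException and chunked inference result injection is suspended for the affected documents.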

File tree: 9 files changed, +621 -27 lines

docs/changelog/125517.yaml (+5)

@@ -0,0 +1,5 @@
+pr: 125517
+summary: Semantic Text Chunking Indexing Pressure
+area: Machine Learning
+type: enhancement
+issues: []

server/src/main/java/org/elasticsearch/index/IndexingPressure.java (+6 -2)

@@ -157,11 +157,15 @@ public Incremental startIncrementalCoordinating(int operations, long bytes, bool
     }
 
     public Coordinating markCoordinatingOperationStarted(int operations, long bytes, boolean forceExecution) {
-        Coordinating coordinating = new Coordinating(forceExecution);
+        Coordinating coordinating = createCoordinatingOperation(forceExecution);
         coordinating.increment(operations, bytes);
         return coordinating;
     }
 
+    public Coordinating createCoordinatingOperation(boolean forceExecution) {
+        return new Coordinating(forceExecution);
+    }
+
     public class Incremental implements Releasable {
 
         private final AtomicBoolean closed = new AtomicBoolean();
@@ -254,7 +258,7 @@ public Coordinating(boolean forceExecution) {
             this.forceExecution = forceExecution;
        }
 
-        private void increment(int operations, long bytes) {
+        public void increment(int operations, long bytes) {
            assert closed.get() == false;
            long combinedBytes = currentCombinedCoordinatingAndPrimaryBytes.addAndGet(bytes);
            long replicaWriteBytes = currentReplicaBytes.get();
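
Taken together, these changes let a caller acquire a Coordinating handle before the total byte count is known and grow it incrementally as memory is consumed; increment is made public for the same reason. A minimal usage sketch (the indexingPressure reference and byte estimate below are stand-ins, not code from this commit):

// Sketch only: acquire, grow, release.
IndexingPressure.Coordinating coordinating = indexingPressure.createCoordinatingOperation(false);
try {
    // With forceExecution == false, increment throws EsRejectedExecutionException
    // once the node's indexing_pressure.memory.limit would be exceeded.
    coordinating.increment(1, estimatedSourceBytes);
    // ... perform the memory-consuming work ...
} finally {
    // Coordinating is a Releasable; closing it returns the accounted bytes.
    coordinating.close();
}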

server/src/main/java/org/elasticsearch/node/NodeConstruction.java (+4 -2)

@@ -928,6 +928,8 @@ public Map<String, String> queryFields() {
             metadataCreateIndexService
         );
 
+        final IndexingPressure indexingLimits = new IndexingPressure(settings);
+
         PluginServiceInstances pluginServices = new PluginServiceInstances(
             client,
             clusterService,
@@ -950,7 +952,8 @@ public Map<String, String> queryFields() {
             documentParsingProvider,
             taskManager,
             projectResolver,
-            slowLogFieldProvider
+            slowLogFieldProvider,
+            indexingLimits
         );
 
         Collection<?> pluginComponents = pluginsService.flatMap(plugin -> {
@@ -983,7 +986,6 @@ public Map<String, String> queryFields() {
             .map(TerminationHandlerProvider::handler);
         terminationHandler = getSinglePlugin(terminationHandlers, TerminationHandler.class).orElse(null);
 
-        final IndexingPressure indexingLimits = new IndexingPressure(settings);
         final IncrementalBulkService incrementalBulkService = new IncrementalBulkService(client, indexingLimits);
 
         final ResponseCollectorService responseCollectorService = new ResponseCollectorService(clusterService);

server/src/main/java/org/elasticsearch/node/PluginServiceInstances.java (+3 -1)

@@ -20,6 +20,7 @@
 import org.elasticsearch.env.Environment;
 import org.elasticsearch.env.NodeEnvironment;
 import org.elasticsearch.features.FeatureService;
+import org.elasticsearch.index.IndexingPressure;
 import org.elasticsearch.index.SlowLogFieldProvider;
 import org.elasticsearch.indices.IndicesService;
 import org.elasticsearch.indices.SystemIndices;
@@ -55,5 +56,6 @@ public record PluginServiceInstances(
     DocumentParsingProvider documentParsingProvider,
     TaskManager taskManager,
     ProjectResolver projectResolver,
-    SlowLogFieldProvider slowLogFieldProvider
+    SlowLogFieldProvider slowLogFieldProvider,
+    IndexingPressure indexingPressure
 ) implements Plugin.PluginServices {}

server/src/main/java/org/elasticsearch/plugins/Plugin.java (+6)

@@ -29,6 +29,7 @@
 import org.elasticsearch.features.FeatureService;
 import org.elasticsearch.index.IndexModule;
 import org.elasticsearch.index.IndexSettingProvider;
+import org.elasticsearch.index.IndexingPressure;
 import org.elasticsearch.index.SlowLogFieldProvider;
 import org.elasticsearch.indices.IndicesService;
 import org.elasticsearch.indices.SystemIndices;
@@ -186,6 +187,11 @@ public interface PluginServices {
          * Provider for additional SlowLog fields
          */
         SlowLogFieldProvider slowLogFieldProvider();
+
+        /**
+         * Provider for indexing pressure
+         */
+        IndexingPressure indexingPressure();
     }
 
     /**
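
With the new accessor on PluginServices, any plugin can reach the node-wide IndexingPressure instance from createComponents. A hypothetical sketch of a consumer (MyPlugin and MyMemoryAwareComponent are illustrative names; the InferencePlugin change below is the real consumer in this commit):

public class MyPlugin extends Plugin {
    @Override
    public Collection<?> createComponents(PluginServices services) {
        // Hypothetical: hand the node's indexing pressure to a component that
        // needs to account for coordinating memory usage.
        IndexingPressure indexingPressure = services.indexingPressure();
        return List.of(new MyMemoryAwareComponent(indexingPressure));
    }
}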

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java (+7 -1)

@@ -321,7 +321,13 @@ public Collection<?> createComponents(PluginServices services) {
         }
         inferenceServiceRegistry.set(serviceRegistry);
 
-        var actionFilter = new ShardBulkInferenceActionFilter(services.clusterService(), serviceRegistry, modelRegistry, getLicenseState());
+        var actionFilter = new ShardBulkInferenceActionFilter(
+            services.clusterService(),
+            serviceRegistry,
+            modelRegistry,
+            getLicenseState(),
+            services.indexingPressure()
+        );
         shardBulkInferenceActionFilter.set(actionFilter);
 
         var meterRegistry = services.telemetryProvider().getMeterRegistry();

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilter.java (+100 -11)

@@ -29,11 +29,13 @@
 import org.elasticsearch.common.settings.Setting;
 import org.elasticsearch.common.unit.ByteSizeValue;
 import org.elasticsearch.common.util.concurrent.AtomicArray;
+import org.elasticsearch.common.util.concurrent.EsRejectedExecutionException;
 import org.elasticsearch.common.xcontent.XContentHelper;
 import org.elasticsearch.common.xcontent.support.XContentMapValues;
 import org.elasticsearch.core.Nullable;
 import org.elasticsearch.core.Releasable;
 import org.elasticsearch.core.TimeValue;
+import org.elasticsearch.index.IndexingPressure;
 import org.elasticsearch.index.mapper.InferenceMetadataFieldsMapper;
 import org.elasticsearch.inference.ChunkInferenceInput;
 import org.elasticsearch.inference.ChunkedInference;
@@ -109,18 +111,21 @@ public class ShardBulkInferenceActionFilter implements MappedActionFilter {
     private final InferenceServiceRegistry inferenceServiceRegistry;
     private final ModelRegistry modelRegistry;
     private final XPackLicenseState licenseState;
+    private final IndexingPressure indexingPressure;
     private volatile long batchSizeInBytes;
 
     public ShardBulkInferenceActionFilter(
         ClusterService clusterService,
         InferenceServiceRegistry inferenceServiceRegistry,
         ModelRegistry modelRegistry,
-        XPackLicenseState licenseState
+        XPackLicenseState licenseState,
+        IndexingPressure indexingPressure
     ) {
         this.clusterService = clusterService;
         this.inferenceServiceRegistry = inferenceServiceRegistry;
         this.modelRegistry = modelRegistry;
         this.licenseState = licenseState;
+        this.indexingPressure = indexingPressure;
         this.batchSizeInBytes = INDICES_INFERENCE_BATCH_SIZE.get(clusterService.getSettings()).getBytes();
         clusterService.getClusterSettings().addSettingsUpdateConsumer(INDICES_INFERENCE_BATCH_SIZE, this::setBatchSize);
     }
@@ -146,8 +151,15 @@ public <Request extends ActionRequest, Response extends ActionResponse> void app
         BulkShardRequest bulkShardRequest = (BulkShardRequest) request;
         var fieldInferenceMetadata = bulkShardRequest.consumeInferenceFieldMap();
         if (fieldInferenceMetadata != null && fieldInferenceMetadata.isEmpty() == false) {
-            Runnable onInferenceCompletion = () -> chain.proceed(task, action, request, listener);
-            processBulkShardRequest(fieldInferenceMetadata, bulkShardRequest, onInferenceCompletion);
+            // Maintain coordinating indexing pressure from inference until the indexing operations are complete
+            IndexingPressure.Coordinating coordinatingIndexingPressure = indexingPressure.createCoordinatingOperation(false);
+            Runnable onInferenceCompletion = () -> chain.proceed(
+                task,
+                action,
+                request,
+                ActionListener.releaseAfter(listener, coordinatingIndexingPressure)
+            );
+            processBulkShardRequest(fieldInferenceMetadata, bulkShardRequest, onInferenceCompletion, coordinatingIndexingPressure);
             return;
         }
     }
@@ -157,12 +169,14 @@ public <Request extends ActionRequest, Response extends ActionResponse> void app
     private void processBulkShardRequest(
         Map<String, InferenceFieldMetadata> fieldInferenceMap,
         BulkShardRequest bulkShardRequest,
-        Runnable onCompletion
+        Runnable onCompletion,
+        IndexingPressure.Coordinating coordinatingIndexingPressure
     ) {
         final ProjectMetadata project = clusterService.state().getMetadata().getProject();
         var index = project.index(bulkShardRequest.index());
         boolean useLegacyFormat = InferenceMetadataFieldsMapper.isEnabled(index.getSettings()) == false;
-        new AsyncBulkShardInferenceAction(useLegacyFormat, fieldInferenceMap, bulkShardRequest, onCompletion).run();
+        new AsyncBulkShardInferenceAction(useLegacyFormat, fieldInferenceMap, bulkShardRequest, onCompletion, coordinatingIndexingPressure)
+            .run();
     }
 
     private record InferenceProvider(InferenceService service, Model model) {}
@@ -232,18 +246,21 @@ private class AsyncBulkShardInferenceAction implements Runnable {
         private final BulkShardRequest bulkShardRequest;
         private final Runnable onCompletion;
         private final AtomicArray<FieldInferenceResponseAccumulator> inferenceResults;
+        private final IndexingPressure.Coordinating coordinatingIndexingPressure;
 
         private AsyncBulkShardInferenceAction(
             boolean useLegacyFormat,
             Map<String, InferenceFieldMetadata> fieldInferenceMap,
             BulkShardRequest bulkShardRequest,
-            Runnable onCompletion
+            Runnable onCompletion,
+            IndexingPressure.Coordinating coordinatingIndexingPressure
         ) {
             this.useLegacyFormat = useLegacyFormat;
             this.fieldInferenceMap = fieldInferenceMap;
            this.bulkShardRequest = bulkShardRequest;
            this.inferenceResults = new AtomicArray<>(bulkShardRequest.items().length);
            this.onCompletion = onCompletion;
+            this.coordinatingIndexingPressure = coordinatingIndexingPressure;
        }
 
        @Override
@@ -431,9 +448,9 @@ public void onFailure(Exception exc) {
          */
         private long addFieldInferenceRequests(BulkItemRequest item, int itemIndex, Map<String, List<FieldInferenceRequest>> requestsMap) {
             boolean isUpdateRequest = false;
-            final IndexRequest indexRequest;
+            final IndexRequestWithIndexingPressure indexRequest;
             if (item.request() instanceof IndexRequest ir) {
-                indexRequest = ir;
+                indexRequest = new IndexRequestWithIndexingPressure(ir);
             } else if (item.request() instanceof UpdateRequest updateRequest) {
                 isUpdateRequest = true;
                 if (updateRequest.script() != null) {
@@ -447,13 +464,13 @@ private long addFieldInferenceRequests(BulkItemRequest item, int itemIndex, Map<
                     );
                     return 0;
                 }
-                indexRequest = updateRequest.doc();
+                indexRequest = new IndexRequestWithIndexingPressure(updateRequest.doc());
             } else {
                 // ignore delete request
                 return 0;
             }
 
-            final Map<String, Object> docMap = indexRequest.sourceAsMap();
+            final Map<String, Object> docMap = indexRequest.getIndexRequest().sourceAsMap();
             long inputLength = 0;
             for (var entry : fieldInferenceMap.values()) {
                 String field = entry.getName();
@@ -489,6 +506,10 @@ private long addFieldInferenceRequests(BulkItemRequest item, int itemIndex, Map<
                  * This ensures that the field is treated as intentionally cleared,
                  * preventing any unintended carryover of prior inference results.
                  */
+                if (incrementIndexingPressure(indexRequest, itemIndex) == false) {
+                    return inputLength;
+                }
+
                 var slot = ensureResponseAccumulatorSlot(itemIndex);
                 slot.addOrUpdateResponse(
                     new FieldInferenceResponse(field, sourceField, null, order++, 0, null, EMPTY_CHUNKED_INFERENCE)
@@ -510,6 +531,7 @@ private long addFieldInferenceRequests(BulkItemRequest item, int itemIndex, Map<
                     }
                     continue;
                 }
+
                 var slot = ensureResponseAccumulatorSlot(itemIndex);
                 final List<String> values;
                 try {
@@ -527,7 +549,10 @@ private long addFieldInferenceRequests(BulkItemRequest item, int itemIndex, Map<
                 List<FieldInferenceRequest> requests = requestsMap.computeIfAbsent(inferenceId, k -> new ArrayList<>());
                 int offsetAdjustment = 0;
                 for (String v : values) {
-                    inputLength += v.length();
+                    if (incrementIndexingPressure(indexRequest, itemIndex) == false) {
+                        return inputLength;
+                    }
+
                     if (v.isBlank()) {
                         slot.addOrUpdateResponse(
                             new FieldInferenceResponse(field, sourceField, v, order++, 0, null, EMPTY_CHUNKED_INFERENCE)
@@ -536,6 +561,7 @@ private long addFieldInferenceRequests(BulkItemRequest item, int itemIndex, Map<
                         requests.add(
                             new FieldInferenceRequest(itemIndex, field, sourceField, v, order++, offsetAdjustment, chunkingSettings)
                         );
+                        inputLength += v.length();
                     }
 
                     // When using the inference metadata fields format, all the input values are concatenated so that the
@@ -545,9 +571,54 @@ private long addFieldInferenceRequests(BulkItemRequest item, int itemIndex, Map<
                     }
                 }
             }
+
             return inputLength;
         }
 
+        private static class IndexRequestWithIndexingPressure {
+            private final IndexRequest indexRequest;
+            private boolean indexingPressureIncremented;
+
+            private IndexRequestWithIndexingPressure(IndexRequest indexRequest) {
+                this.indexRequest = indexRequest;
+                this.indexingPressureIncremented = false;
+            }
+
+            private IndexRequest getIndexRequest() {
+                return indexRequest;
+            }
+
+            private boolean isIndexingPressureIncremented() {
+                return indexingPressureIncremented;
+            }
+
+            private void setIndexingPressureIncremented() {
+                this.indexingPressureIncremented = true;
+            }
+        }
+
+        private boolean incrementIndexingPressure(IndexRequestWithIndexingPressure indexRequest, int itemIndex) {
+            boolean success = true;
+            if (indexRequest.isIndexingPressureIncremented() == false) {
+                try {
+                    // Track operation count as one operation per document source update
+                    coordinatingIndexingPressure.increment(1, indexRequest.getIndexRequest().source().ramBytesUsed());
+                    indexRequest.setIndexingPressureIncremented();
+                } catch (EsRejectedExecutionException e) {
+                    addInferenceResponseFailure(
+                        itemIndex,
+                        new InferenceException(
+                            "Insufficient memory available to update source on document [" + indexRequest.getIndexRequest().id() + "]",
+                            e
+                        )
+                    );
+                    success = false;
+                }
+            }
+
+            return success;
        }
+
         private FieldInferenceResponseAccumulator ensureResponseAccumulatorSlot(int id) {
             FieldInferenceResponseAccumulator acc = inferenceResults.get(id);
             if (acc == null) {
@@ -624,6 +695,7 @@ private void applyInferenceResponses(BulkItemRequest item, FieldInferenceRespons
                 inferenceFieldsMap.put(fieldName, result);
             }
 
+            BytesReference originalSource = indexRequest.source();
             if (useLegacyFormat) {
                 var newDocMap = indexRequest.sourceAsMap();
                 for (var entry : inferenceFieldsMap.entrySet()) {
@@ -636,6 +708,23 @@ private void applyInferenceResponses(BulkItemRequest item, FieldInferenceRespons
                     indexRequest.source(builder);
                 }
             }
+            long modifiedSourceSize = indexRequest.source().ramBytesUsed();
+
+            // Add the indexing pressure from the source modifications.
+            // Don't increment operation count because we count one source update as one operation, and we already accounted for those
+            // in addFieldInferenceRequests.
+            try {
+                coordinatingIndexingPressure.increment(0, modifiedSourceSize - originalSource.ramBytesUsed());
+            } catch (EsRejectedExecutionException e) {
+                indexRequest.source(originalSource, indexRequest.getContentType());
+                item.abort(
+                    item.index(),
+                    new InferenceException(
+                        "Insufficient memory available to insert inference results into document [" + indexRequest.id() + "]",
+                        e
+                    )
+                );
+            }
         }
     }
 
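For readers skimming the diff: accounting happens in two phases against a single Coordinating handle that lives for the duration of the bulk request. A condensed, illustrative sketch of the lifecycle (simplified names, not the filter's actual control flow):

IndexingPressure.Coordinating pressure = indexingPressure.createCoordinatingOperation(false);
try {
    // Phase 1 (addFieldInferenceRequests): one operation per document,
    // sized by the original source bytes.
    pressure.increment(1, indexRequest.source().ramBytesUsed());

    long before = indexRequest.source().ramBytesUsed();
    // ... run chunked inference and inject the results into the source ...
    long after = indexRequest.source().ramBytesUsed();

    // Phase 2 (applyInferenceResponses): account only for the growth caused by
    // injecting the inference results; no additional operation is counted.
    pressure.increment(0, after - before);
} catch (EsRejectedExecutionException e) {
    // Over indexing_pressure.memory.limit: the individual document is failed
    // (or the item aborted and its source restored) instead of risking an OOM.
}
// The handle is released via ActionListener.releaseAfter(listener, pressure)
// once the bulk request completes.

The design point is that exceeding the limit degrades individual documents with a rejection error rather than destabilizing the whole node.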
