Commit 30d5a9d

[8.19] Semantic Text Chunking Indexing Pressure (elastic#125517) (elastic#127463)
* Semantic Text Chunking Indexing Pressure (elastic#125517)

  We have observed many OOMs due to the memory required to inject chunked inference results for semantic_text fields. This PR uses coordinating indexing pressure to account for this memory usage. When indexing pressure memory usage exceeds the threshold set by indexing_pressure.memory.limit, chunked inference result injection will be suspended to prevent OOMs.

  (cherry picked from commit 85713f7)

  # Conflicts:
  #   server/src/main/java/org/elasticsearch/node/NodeConstruction.java
  #   server/src/main/java/org/elasticsearch/node/PluginServiceInstances.java
  #   x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java
  #   x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilterTests.java

* [CI] Auto commit changes from spotless

---------

Co-authored-by: elasticsearchmachine <[email protected]>
1 parent 7eb393b · commit 30d5a9d
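Before the diffs, a minimal sketch of the pattern this commit introduces, assuming the wiring shown below: a coordinating indexing-pressure operation brackets chunked-inference injection, and its tracked bytes are released only once downstream indexing completes. The class and method names here are invented for illustration; the IndexingPressure and ActionListener calls mirror the diff (in the actual PR, the IndexingPressure instance is injected via PluginServices rather than constructed locally).

    import org.elasticsearch.action.ActionListener;
    import org.elasticsearch.action.index.IndexRequest;
    import org.elasticsearch.common.settings.Settings;
    import org.elasticsearch.common.util.concurrent.EsRejectedExecutionException;
    import org.elasticsearch.index.IndexingPressure;

    // Hypothetical caller, for illustration only.
    class ChunkedInferencePressureSketch {
        void run(Settings settings, IndexRequest indexRequest, ActionListener<Void> listener) {
            IndexingPressure indexingPressure = new IndexingPressure(settings);
            // forceExecution = false: the operation can be rejected once memory
            // usage exceeds indexing_pressure.memory.limit.
            IndexingPressure.Coordinating coordinating = indexingPressure.createCoordinatingOperation(false);
            try {
                // One operation per document, sized by its source bytes.
                coordinating.increment(1, indexRequest.source().ramBytesUsed());
                // ... run inference and inject chunked results into the source ...
            } catch (EsRejectedExecutionException e) {
                // Over the limit: fail the item instead of risking an OOM.
            }
            // Release the tracked bytes only after indexing completes.
            ActionListener.releaseAfter(listener, coordinating).onResponse(null);
        }
    }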

9 files changed (+617 -27)

docs/changelog/125517.yaml (+5)

@@ -0,0 +1,5 @@
+pr: 125517
+summary: Semantic Text Chunking Indexing Pressure
+area: Machine Learning
+type: enhancement
+issues: []

server/src/main/java/org/elasticsearch/index/IndexingPressure.java (+6 -2)

@@ -146,11 +146,15 @@ public Incremental startIncrementalCoordinating(int operations, long bytes, bool
     }

     public Coordinating markCoordinatingOperationStarted(int operations, long bytes, boolean forceExecution) {
-        Coordinating coordinating = new Coordinating(forceExecution);
+        Coordinating coordinating = createCoordinatingOperation(forceExecution);
         coordinating.increment(operations, bytes);
         return coordinating;
     }

+    public Coordinating createCoordinatingOperation(boolean forceExecution) {
+        return new Coordinating(forceExecution);
+    }
+
     public class Incremental implements Releasable {

         private final AtomicBoolean closed = new AtomicBoolean();

@@ -243,7 +247,7 @@ public Coordinating(boolean forceExecution) {
         this.forceExecution = forceExecution;
     }

-    private void increment(int operations, long bytes) {
+    public void increment(int operations, long bytes) {
         assert closed.get() == false;
         long combinedBytes = currentCombinedCoordinatingAndPrimaryBytes.addAndGet(bytes);
         long replicaWriteBytes = currentReplicaBytes.get();
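The change above factors Coordinating creation out of markCoordinatingOperationStarted and widens increment to public, so a caller can open an empty coordinating operation and grow it in steps instead of sizing it once up front. A minimal usage sketch, assuming an indexingPressure instance is already in scope and using placeholder byte counts:

    IndexingPressure.Coordinating coordinating = indexingPressure.createCoordinatingOperation(false);
    try {
        coordinating.increment(1, 2_048); // count one operation with its initial bytes
        coordinating.increment(0, 512);   // later: add bytes only, no extra operation
    } catch (EsRejectedExecutionException e) {
        // thrown when the added bytes would push usage over indexing_pressure.memory.limit
    } finally {
        coordinating.close(); // Coordinating is Releasable; this frees all tracked bytes
    }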

server/src/main/java/org/elasticsearch/node/NodeConstruction.java (+4 -2)

@@ -936,6 +936,8 @@ public Map<String, String> queryFields() {
             metadataCreateIndexService
         );

+        final IndexingPressure indexingLimits = new IndexingPressure(settings);
+
         PluginServiceInstances pluginServices = new PluginServiceInstances(
             client,
             clusterService,

@@ -957,7 +959,8 @@ public Map<String, String> queryFields() {
             dataStreamGlobalRetentionSettings,
             documentParsingProvider,
             taskManager,
-            slowLogFieldProvider
+            slowLogFieldProvider,
+            indexingLimits
         );

         Collection<?> pluginComponents = pluginsService.flatMap(plugin -> {

@@ -990,7 +993,6 @@ public Map<String, String> queryFields() {
             .map(TerminationHandlerProvider::handler);
         terminationHandler = getSinglePlugin(terminationHandlers, TerminationHandler.class).orElse(null);

-        final IndexingPressure indexingLimits = new IndexingPressure(settings);
         final IncrementalBulkService incrementalBulkService = new IncrementalBulkService(client, indexingLimits);

         ActionModule actionModule = new ActionModule(

server/src/main/java/org/elasticsearch/node/PluginServiceInstances.java (+3 -1)

@@ -19,6 +19,7 @@
 import org.elasticsearch.env.Environment;
 import org.elasticsearch.env.NodeEnvironment;
 import org.elasticsearch.features.FeatureService;
+import org.elasticsearch.index.IndexingPressure;
 import org.elasticsearch.index.SlowLogFieldProvider;
 import org.elasticsearch.indices.IndicesService;
 import org.elasticsearch.indices.SystemIndices;

@@ -53,5 +54,6 @@ public record PluginServiceInstances(
     DataStreamGlobalRetentionSettings dataStreamGlobalRetentionSettings,
     DocumentParsingProvider documentParsingProvider,
     TaskManager taskManager,
-    SlowLogFieldProvider slowLogFieldProvider
+    SlowLogFieldProvider slowLogFieldProvider,
+    IndexingPressure indexingPressure
 ) implements Plugin.PluginServices {}

server/src/main/java/org/elasticsearch/plugins/Plugin.java (+6)

@@ -27,6 +27,7 @@
 import org.elasticsearch.features.FeatureService;
 import org.elasticsearch.index.IndexModule;
 import org.elasticsearch.index.IndexSettingProvider;
+import org.elasticsearch.index.IndexingPressure;
 import org.elasticsearch.index.SlowLogFieldProvider;
 import org.elasticsearch.indices.IndicesService;
 import org.elasticsearch.indices.SystemIndices;

@@ -179,6 +180,11 @@ public interface PluginServices {
          * Provider for additional SlowLog fields
          */
         SlowLogFieldProvider slowLogFieldProvider();
+
+        /**
+         * Provider for indexing pressure
+         */
+        IndexingPressure indexingPressure();
     }

     /**

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java (+2 -1)

@@ -328,7 +328,8 @@ public Collection<?> createComponents(PluginServices services) {
             services.clusterService(),
             serviceRegistry,
             modelRegistry.get(),
-            getLicenseState()
+            getLicenseState(),
+            services.indexingPressure()
         );
         shardBulkInferenceActionFilter.set(actionFilter);

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilter.java (+100 -11)

@@ -28,11 +28,13 @@
 import org.elasticsearch.common.settings.Setting;
 import org.elasticsearch.common.unit.ByteSizeValue;
 import org.elasticsearch.common.util.concurrent.AtomicArray;
+import org.elasticsearch.common.util.concurrent.EsRejectedExecutionException;
 import org.elasticsearch.common.xcontent.XContentHelper;
 import org.elasticsearch.common.xcontent.support.XContentMapValues;
 import org.elasticsearch.core.Nullable;
 import org.elasticsearch.core.Releasable;
 import org.elasticsearch.core.TimeValue;
+import org.elasticsearch.index.IndexingPressure;
 import org.elasticsearch.index.mapper.InferenceMetadataFieldsMapper;
 import org.elasticsearch.inference.ChunkInferenceInput;
 import org.elasticsearch.inference.ChunkedInference;

@@ -108,18 +110,21 @@ public class ShardBulkInferenceActionFilter implements MappedActionFilter {
     private final InferenceServiceRegistry inferenceServiceRegistry;
     private final ModelRegistry modelRegistry;
     private final XPackLicenseState licenseState;
+    private final IndexingPressure indexingPressure;
     private volatile long batchSizeInBytes;

     public ShardBulkInferenceActionFilter(
         ClusterService clusterService,
         InferenceServiceRegistry inferenceServiceRegistry,
         ModelRegistry modelRegistry,
-        XPackLicenseState licenseState
+        XPackLicenseState licenseState,
+        IndexingPressure indexingPressure
     ) {
         this.clusterService = clusterService;
         this.inferenceServiceRegistry = inferenceServiceRegistry;
         this.modelRegistry = modelRegistry;
         this.licenseState = licenseState;
+        this.indexingPressure = indexingPressure;
         this.batchSizeInBytes = INDICES_INFERENCE_BATCH_SIZE.get(clusterService.getSettings()).getBytes();
         clusterService.getClusterSettings().addSettingsUpdateConsumer(INDICES_INFERENCE_BATCH_SIZE, this::setBatchSize);
     }

@@ -145,8 +150,15 @@ public <Request extends ActionRequest, Response extends ActionResponse> void app
         BulkShardRequest bulkShardRequest = (BulkShardRequest) request;
         var fieldInferenceMetadata = bulkShardRequest.consumeInferenceFieldMap();
         if (fieldInferenceMetadata != null && fieldInferenceMetadata.isEmpty() == false) {
-            Runnable onInferenceCompletion = () -> chain.proceed(task, action, request, listener);
-            processBulkShardRequest(fieldInferenceMetadata, bulkShardRequest, onInferenceCompletion);
+            // Maintain coordinating indexing pressure from inference until the indexing operations are complete
+            IndexingPressure.Coordinating coordinatingIndexingPressure = indexingPressure.createCoordinatingOperation(false);
+            Runnable onInferenceCompletion = () -> chain.proceed(
+                task,
+                action,
+                request,
+                ActionListener.releaseAfter(listener, coordinatingIndexingPressure)
+            );
+            processBulkShardRequest(fieldInferenceMetadata, bulkShardRequest, onInferenceCompletion, coordinatingIndexingPressure);
             return;
         }
     }

@@ -156,11 +168,13 @@ public <Request extends ActionRequest, Response extends ActionResponse> void app
     private void processBulkShardRequest(
         Map<String, InferenceFieldMetadata> fieldInferenceMap,
         BulkShardRequest bulkShardRequest,
-        Runnable onCompletion
+        Runnable onCompletion,
+        IndexingPressure.Coordinating coordinatingIndexingPressure
     ) {
         var index = clusterService.state().getMetadata().index(bulkShardRequest.index());
         boolean useLegacyFormat = InferenceMetadataFieldsMapper.isEnabled(index.getSettings()) == false;
-        new AsyncBulkShardInferenceAction(useLegacyFormat, fieldInferenceMap, bulkShardRequest, onCompletion).run();
+        new AsyncBulkShardInferenceAction(useLegacyFormat, fieldInferenceMap, bulkShardRequest, onCompletion, coordinatingIndexingPressure)
+            .run();
     }

     private record InferenceProvider(InferenceService service, Model model) {}

@@ -230,18 +244,21 @@ private class AsyncBulkShardInferenceAction implements Runnable {
         private final BulkShardRequest bulkShardRequest;
         private final Runnable onCompletion;
         private final AtomicArray<FieldInferenceResponseAccumulator> inferenceResults;
+        private final IndexingPressure.Coordinating coordinatingIndexingPressure;

         private AsyncBulkShardInferenceAction(
             boolean useLegacyFormat,
             Map<String, InferenceFieldMetadata> fieldInferenceMap,
             BulkShardRequest bulkShardRequest,
-            Runnable onCompletion
+            Runnable onCompletion,
+            IndexingPressure.Coordinating coordinatingIndexingPressure
         ) {
             this.useLegacyFormat = useLegacyFormat;
             this.fieldInferenceMap = fieldInferenceMap;
             this.bulkShardRequest = bulkShardRequest;
             this.inferenceResults = new AtomicArray<>(bulkShardRequest.items().length);
             this.onCompletion = onCompletion;
+            this.coordinatingIndexingPressure = coordinatingIndexingPressure;
         }

         @Override

@@ -429,9 +446,9 @@ public void onFailure(Exception exc) {
         */
        private long addFieldInferenceRequests(BulkItemRequest item, int itemIndex, Map<String, List<FieldInferenceRequest>> requestsMap) {
            boolean isUpdateRequest = false;
-           final IndexRequest indexRequest;
+           final IndexRequestWithIndexingPressure indexRequest;
            if (item.request() instanceof IndexRequest ir) {
-               indexRequest = ir;
+               indexRequest = new IndexRequestWithIndexingPressure(ir);
            } else if (item.request() instanceof UpdateRequest updateRequest) {
                isUpdateRequest = true;
                if (updateRequest.script() != null) {

@@ -445,13 +462,13 @@ private long addFieldInferenceRequests(BulkItemRequest item, int itemIndex, Map<
                    );
                    return 0;
                }
-               indexRequest = updateRequest.doc();
+               indexRequest = new IndexRequestWithIndexingPressure(updateRequest.doc());
            } else {
                // ignore delete request
                return 0;
            }

-           final Map<String, Object> docMap = indexRequest.sourceAsMap();
+           final Map<String, Object> docMap = indexRequest.getIndexRequest().sourceAsMap();
            long inputLength = 0;
            for (var entry : fieldInferenceMap.values()) {
                String field = entry.getName();

@@ -487,6 +504,10 @@ private long addFieldInferenceRequests(BulkItemRequest item, int itemIndex, Map<
                     * This ensures that the field is treated as intentionally cleared,
                     * preventing any unintended carryover of prior inference results.
                     */
+                    if (incrementIndexingPressure(indexRequest, itemIndex) == false) {
+                        return inputLength;
+                    }
+
                    var slot = ensureResponseAccumulatorSlot(itemIndex);
                    slot.addOrUpdateResponse(
                        new FieldInferenceResponse(field, sourceField, null, order++, 0, null, EMPTY_CHUNKED_INFERENCE)

@@ -508,6 +529,7 @@ private long addFieldInferenceRequests(BulkItemRequest item, int itemIndex, Map<
                    }
                    continue;
                }
+
                var slot = ensureResponseAccumulatorSlot(itemIndex);
                final List<String> values;
                try {

@@ -525,7 +547,10 @@ private long addFieldInferenceRequests(BulkItemRequest item, int itemIndex, Map<
                List<FieldInferenceRequest> requests = requestsMap.computeIfAbsent(inferenceId, k -> new ArrayList<>());
                int offsetAdjustment = 0;
                for (String v : values) {
-                   inputLength += v.length();
+                   if (incrementIndexingPressure(indexRequest, itemIndex) == false) {
+                       return inputLength;
+                   }
+
                    if (v.isBlank()) {
                        slot.addOrUpdateResponse(
                            new FieldInferenceResponse(field, sourceField, v, order++, 0, null, EMPTY_CHUNKED_INFERENCE)

@@ -534,6 +559,7 @@ private long addFieldInferenceRequests(BulkItemRequest item, int itemIndex, Map<
                        requests.add(
                            new FieldInferenceRequest(itemIndex, field, sourceField, v, order++, offsetAdjustment, chunkingSettings)
                        );
+                       inputLength += v.length();
                    }

                    // When using the inference metadata fields format, all the input values are concatenated so that the

@@ -543,9 +569,54 @@ private long addFieldInferenceRequests(BulkItemRequest item, int itemIndex, Map<
                    }
                }
            }
+
            return inputLength;
        }

+       private static class IndexRequestWithIndexingPressure {
+           private final IndexRequest indexRequest;
+           private boolean indexingPressureIncremented;
+
+           private IndexRequestWithIndexingPressure(IndexRequest indexRequest) {
+               this.indexRequest = indexRequest;
+               this.indexingPressureIncremented = false;
+           }
+
+           private IndexRequest getIndexRequest() {
+               return indexRequest;
+           }
+
+           private boolean isIndexingPressureIncremented() {
+               return indexingPressureIncremented;
+           }
+
+           private void setIndexingPressureIncremented() {
+               this.indexingPressureIncremented = true;
+           }
+       }
+
+       private boolean incrementIndexingPressure(IndexRequestWithIndexingPressure indexRequest, int itemIndex) {
+           boolean success = true;
+           if (indexRequest.isIndexingPressureIncremented() == false) {
+               try {
+                   // Track operation count as one operation per document source update
+                   coordinatingIndexingPressure.increment(1, indexRequest.getIndexRequest().source().ramBytesUsed());
+                   indexRequest.setIndexingPressureIncremented();
+               } catch (EsRejectedExecutionException e) {
+                   addInferenceResponseFailure(
+                       itemIndex,
+                       new InferenceException(
+                           "Insufficient memory available to update source on document [" + indexRequest.getIndexRequest().id() + "]",
+                           e
+                       )
+                   );
+                   success = false;
+               }
+           }
+
+           return success;
+       }
+
        private FieldInferenceResponseAccumulator ensureResponseAccumulatorSlot(int id) {
            FieldInferenceResponseAccumulator acc = inferenceResults.get(id);
            if (acc == null) {

@@ -622,6 +693,7 @@ private void applyInferenceResponses(BulkItemRequest item, FieldInferenceRespons
                inferenceFieldsMap.put(fieldName, result);
            }

+           BytesReference originalSource = indexRequest.source();
            if (useLegacyFormat) {
                var newDocMap = indexRequest.sourceAsMap();
                for (var entry : inferenceFieldsMap.entrySet()) {

@@ -634,6 +706,23 @@ private void applyInferenceResponses(BulkItemRequest item, FieldInferenceRespons
                    indexRequest.source(builder);
                }
            }
+           long modifiedSourceSize = indexRequest.source().ramBytesUsed();
+
+           // Add the indexing pressure from the source modifications.
+           // Don't increment operation count because we count one source update as one operation, and we already accounted for those
+           // in addFieldInferenceRequests.
+           try {
+               coordinatingIndexingPressure.increment(0, modifiedSourceSize - originalSource.ramBytesUsed());
+           } catch (EsRejectedExecutionException e) {
+               indexRequest.source(originalSource, indexRequest.getContentType());
+               item.abort(
+                   item.index(),
+                   new InferenceException(
+                       "Insufficient memory available to insert inference results into document [" + indexRequest.id() + "]",
+                       e
+                   )
+               );
+           }
        }
    }
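Taken together, the filter accounts for memory in two phases. The sketch below condenses that flow using the variable names from the diff, with error handling reduced to comments:

    // Phase 1 (addFieldInferenceRequests): count each document once, sized by
    // its original source, before any inference runs.
    coordinatingIndexingPressure.increment(1, indexRequest.source().ramBytesUsed());

    // Phase 2 (applyInferenceResponses): after chunked results are injected,
    // add only the growth in source size; operations stay at 0 because the
    // document was already counted in phase 1.
    coordinatingIndexingPressure.increment(0, modifiedSourceSize - originalSource.ramBytesUsed());

    // Either increment can throw EsRejectedExecutionException once memory usage
    // exceeds indexing_pressure.memory.limit: phase 1 fails the item with an
    // InferenceException, phase 2 restores the original source and aborts the item.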
