[8.19] Semantic Text Chunking Indexing Pressure (#125517) #127463

Merged 2 commits on Apr 28, 2025

5 changes: 5 additions & 0 deletions docs/changelog/125517.yaml
@@ -0,0 +1,5 @@
pr: 125517
summary: Semantic Text Chunking Indexing Pressure
area: Machine Learning
type: enhancement
issues: []
@@ -146,11 +146,15 @@ public Incremental startIncrementalCoordinating(int operations, long bytes, bool
}

public Coordinating markCoordinatingOperationStarted(int operations, long bytes, boolean forceExecution) {
Coordinating coordinating = new Coordinating(forceExecution);
Coordinating coordinating = createCoordinatingOperation(forceExecution);
coordinating.increment(operations, bytes);
return coordinating;
}

public Coordinating createCoordinatingOperation(boolean forceExecution) {
return new Coordinating(forceExecution);
}

public class Incremental implements Releasable {

private final AtomicBoolean closed = new AtomicBoolean();
@@ -243,7 +247,7 @@ public Coordinating(boolean forceExecution) {
this.forceExecution = forceExecution;
}

private void increment(int operations, long bytes) {
public void increment(int operations, long bytes) {
assert closed.get() == false;
long combinedBytes = currentCombinedCoordinatingAndPrimaryBytes.addAndGet(bytes);
long replicaWriteBytes = currentReplicaBytes.get();
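
The refactor above splits creating a coordinating operation from incrementing it, and widens `increment` from private to public, so a caller can open the operation before the final byte count is known and account for bytes incrementally. A minimal sketch of the resulting lifecycle, assuming an `IndexingPressure` instance named `indexingPressure` and illustrative byte counts (`Coordinating` is a `Releasable`, so closing it returns the tracked bytes):

```java
// Sketch only: lifecycle of the refactored coordinating-operation API.
// `indexingPressure` and the byte counts here are illustrative assumptions.
IndexingPressure.Coordinating coordinating = indexingPressure.createCoordinatingOperation(false);
try (coordinating) {
    // Reserve one operation and the bytes known so far. increment() throws
    // EsRejectedExecutionException once the configured limits are exceeded.
    coordinating.increment(1, 2048L);
    // Later growth can be added without counting a new operation.
    coordinating.increment(0, 512L);
    // ... perform the coordinating work ...
} // close() releases everything tracked by this operation
```
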
@@ -936,6 +936,8 @@ public Map<String, String> queryFields() {
metadataCreateIndexService
);

final IndexingPressure indexingLimits = new IndexingPressure(settings);

PluginServiceInstances pluginServices = new PluginServiceInstances(
client,
clusterService,
@@ -957,7 +959,8 @@ public Map<String, String> queryFields() {
dataStreamGlobalRetentionSettings,
documentParsingProvider,
taskManager,
slowLogFieldProvider
slowLogFieldProvider,
indexingLimits
);

Collection<?> pluginComponents = pluginsService.flatMap(plugin -> {
@@ -990,7 +993,6 @@ public Map<String, String> queryFields() {
.map(TerminationHandlerProvider::handler);
terminationHandler = getSinglePlugin(terminationHandlers, TerminationHandler.class).orElse(null);

final IndexingPressure indexingLimits = new IndexingPressure(settings);
final IncrementalBulkService incrementalBulkService = new IncrementalBulkService(client, indexingLimits);

ActionModule actionModule = new ActionModule(
@@ -19,6 +19,7 @@
import org.elasticsearch.env.Environment;
import org.elasticsearch.env.NodeEnvironment;
import org.elasticsearch.features.FeatureService;
import org.elasticsearch.index.IndexingPressure;
import org.elasticsearch.index.SlowLogFieldProvider;
import org.elasticsearch.indices.IndicesService;
import org.elasticsearch.indices.SystemIndices;
@@ -53,5 +54,6 @@ public record PluginServiceInstances(
DataStreamGlobalRetentionSettings dataStreamGlobalRetentionSettings,
DocumentParsingProvider documentParsingProvider,
TaskManager taskManager,
SlowLogFieldProvider slowLogFieldProvider
SlowLogFieldProvider slowLogFieldProvider,
IndexingPressure indexingPressure
) implements Plugin.PluginServices {}
6 changes: 6 additions & 0 deletions server/src/main/java/org/elasticsearch/plugins/Plugin.java
@@ -27,6 +27,7 @@
import org.elasticsearch.features.FeatureService;
import org.elasticsearch.index.IndexModule;
import org.elasticsearch.index.IndexSettingProvider;
import org.elasticsearch.index.IndexingPressure;
import org.elasticsearch.index.SlowLogFieldProvider;
import org.elasticsearch.indices.IndicesService;
import org.elasticsearch.indices.SystemIndices;
@@ -179,6 +180,11 @@ public interface PluginServices {
* Provider for additional SlowLog fields
*/
SlowLogFieldProvider slowLogFieldProvider();

/**
* Provider for indexing pressure
*/
IndexingPressure indexingPressure();
}

/**
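
Adding the accessor to `Plugin.PluginServices` lets any plugin obtain the node-wide `IndexingPressure` during component creation. A hypothetical plugin (the class and component names are illustrative, not part of this PR) would consume it as below; the real consumer is the inference plugin change that follows:

```java
import java.util.Collection;
import java.util.List;

import org.elasticsearch.index.IndexingPressure;
import org.elasticsearch.plugins.Plugin;

// Hypothetical plugin wiring the shared IndexingPressure into its own component.
public class MyPlugin extends Plugin {
    @Override
    public Collection<?> createComponents(PluginServices services) {
        IndexingPressure indexingPressure = services.indexingPressure();
        // MyPressureAwareComponent is an illustrative stand-in for a component
        // that participates in indexing-pressure accounting.
        return List.of(new MyPressureAwareComponent(indexingPressure));
    }
}
```
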
@@ -328,7 +328,8 @@ public Collection<?> createComponents(PluginServices services) {
services.clusterService(),
serviceRegistry,
modelRegistry.get(),
getLicenseState()
getLicenseState(),
services.indexingPressure()
);
shardBulkInferenceActionFilter.set(actionFilter);

@@ -28,11 +28,13 @@
import org.elasticsearch.common.settings.Setting;
import org.elasticsearch.common.unit.ByteSizeValue;
import org.elasticsearch.common.util.concurrent.AtomicArray;
import org.elasticsearch.common.util.concurrent.EsRejectedExecutionException;
import org.elasticsearch.common.xcontent.XContentHelper;
import org.elasticsearch.common.xcontent.support.XContentMapValues;
import org.elasticsearch.core.Nullable;
import org.elasticsearch.core.Releasable;
import org.elasticsearch.core.TimeValue;
import org.elasticsearch.index.IndexingPressure;
import org.elasticsearch.index.mapper.InferenceMetadataFieldsMapper;
import org.elasticsearch.inference.ChunkInferenceInput;
import org.elasticsearch.inference.ChunkedInference;
@@ -108,18 +110,21 @@ public class ShardBulkInferenceActionFilter implements MappedActionFilter {
private final InferenceServiceRegistry inferenceServiceRegistry;
private final ModelRegistry modelRegistry;
private final XPackLicenseState licenseState;
private final IndexingPressure indexingPressure;
private volatile long batchSizeInBytes;

public ShardBulkInferenceActionFilter(
ClusterService clusterService,
InferenceServiceRegistry inferenceServiceRegistry,
ModelRegistry modelRegistry,
XPackLicenseState licenseState
XPackLicenseState licenseState,
IndexingPressure indexingPressure
) {
this.clusterService = clusterService;
this.inferenceServiceRegistry = inferenceServiceRegistry;
this.modelRegistry = modelRegistry;
this.licenseState = licenseState;
this.indexingPressure = indexingPressure;
this.batchSizeInBytes = INDICES_INFERENCE_BATCH_SIZE.get(clusterService.getSettings()).getBytes();
clusterService.getClusterSettings().addSettingsUpdateConsumer(INDICES_INFERENCE_BATCH_SIZE, this::setBatchSize);
}
@@ -145,8 +150,15 @@ public <Request extends ActionRequest, Response extends ActionResponse> void app
BulkShardRequest bulkShardRequest = (BulkShardRequest) request;
var fieldInferenceMetadata = bulkShardRequest.consumeInferenceFieldMap();
if (fieldInferenceMetadata != null && fieldInferenceMetadata.isEmpty() == false) {
Runnable onInferenceCompletion = () -> chain.proceed(task, action, request, listener);
processBulkShardRequest(fieldInferenceMetadata, bulkShardRequest, onInferenceCompletion);
// Maintain coordinating indexing pressure from inference until the indexing operations are complete
IndexingPressure.Coordinating coordinatingIndexingPressure = indexingPressure.createCoordinatingOperation(false);
Runnable onInferenceCompletion = () -> chain.proceed(
task,
action,
request,
ActionListener.releaseAfter(listener, coordinatingIndexingPressure)
);
processBulkShardRequest(fieldInferenceMetadata, bulkShardRequest, onInferenceCompletion, coordinatingIndexingPressure);
return;
}
}
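
Wrapping the listener with `ActionListener.releaseAfter` guarantees the coordinating operation is closed exactly once when the listener completes, on both the success and failure paths, so bytes reserved during inference cannot leak if the bulk request fails downstream. A rough sketch of those semantics, assuming `org.elasticsearch.action.ActionListener` and `org.elasticsearch.core.Releasable` (the helper name is hypothetical, and whether the real implementation releases before or after delegating is an implementation detail):

```java
// Sketch of what ActionListener.releaseAfter(delegate, releasable) provides:
// the Releasable is always closed once the listener fires, on either path.
static <T> ActionListener<T> releaseAfterSketch(ActionListener<T> delegate, Releasable releasable) {
    return new ActionListener<>() {
        @Override
        public void onResponse(T response) {
            try {
                delegate.onResponse(response);
            } finally {
                releasable.close();
            }
        }

        @Override
        public void onFailure(Exception e) {
            try {
                delegate.onFailure(e);
            } finally {
                releasable.close();
            }
        }
    };
}
```
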
@@ -156,11 +168,13 @@ public <Request extends ActionRequest, Response extends ActionResponse> void app
private void processBulkShardRequest(
Map<String, InferenceFieldMetadata> fieldInferenceMap,
BulkShardRequest bulkShardRequest,
Runnable onCompletion
Runnable onCompletion,
IndexingPressure.Coordinating coordinatingIndexingPressure
) {
var index = clusterService.state().getMetadata().index(bulkShardRequest.index());
boolean useLegacyFormat = InferenceMetadataFieldsMapper.isEnabled(index.getSettings()) == false;
new AsyncBulkShardInferenceAction(useLegacyFormat, fieldInferenceMap, bulkShardRequest, onCompletion).run();
new AsyncBulkShardInferenceAction(useLegacyFormat, fieldInferenceMap, bulkShardRequest, onCompletion, coordinatingIndexingPressure)
.run();
}

private record InferenceProvider(InferenceService service, Model model) {}
@@ -230,18 +244,21 @@ private class AsyncBulkShardInferenceAction implements Runnable {
private final BulkShardRequest bulkShardRequest;
private final Runnable onCompletion;
private final AtomicArray<FieldInferenceResponseAccumulator> inferenceResults;
private final IndexingPressure.Coordinating coordinatingIndexingPressure;

private AsyncBulkShardInferenceAction(
boolean useLegacyFormat,
Map<String, InferenceFieldMetadata> fieldInferenceMap,
BulkShardRequest bulkShardRequest,
Runnable onCompletion
Runnable onCompletion,
IndexingPressure.Coordinating coordinatingIndexingPressure
) {
this.useLegacyFormat = useLegacyFormat;
this.fieldInferenceMap = fieldInferenceMap;
this.bulkShardRequest = bulkShardRequest;
this.inferenceResults = new AtomicArray<>(bulkShardRequest.items().length);
this.onCompletion = onCompletion;
this.coordinatingIndexingPressure = coordinatingIndexingPressure;
}

@Override
@@ -429,9 +446,9 @@ public void onFailure(Exception exc) {
*/
private long addFieldInferenceRequests(BulkItemRequest item, int itemIndex, Map<String, List<FieldInferenceRequest>> requestsMap) {
boolean isUpdateRequest = false;
final IndexRequest indexRequest;
final IndexRequestWithIndexingPressure indexRequest;
if (item.request() instanceof IndexRequest ir) {
indexRequest = ir;
indexRequest = new IndexRequestWithIndexingPressure(ir);
} else if (item.request() instanceof UpdateRequest updateRequest) {
isUpdateRequest = true;
if (updateRequest.script() != null) {
Expand All @@ -445,13 +462,13 @@ private long addFieldInferenceRequests(BulkItemRequest item, int itemIndex, Map<
);
return 0;
}
indexRequest = updateRequest.doc();
indexRequest = new IndexRequestWithIndexingPressure(updateRequest.doc());
} else {
// ignore delete request
return 0;
}

final Map<String, Object> docMap = indexRequest.sourceAsMap();
final Map<String, Object> docMap = indexRequest.getIndexRequest().sourceAsMap();
long inputLength = 0;
for (var entry : fieldInferenceMap.values()) {
String field = entry.getName();
@@ -487,6 +504,10 @@ private long addFieldInferenceRequests(BulkItemRequest item, int itemIndex, Map<
* This ensures that the field is treated as intentionally cleared,
* preventing any unintended carryover of prior inference results.
*/
if (incrementIndexingPressure(indexRequest, itemIndex) == false) {
return inputLength;
}

var slot = ensureResponseAccumulatorSlot(itemIndex);
slot.addOrUpdateResponse(
new FieldInferenceResponse(field, sourceField, null, order++, 0, null, EMPTY_CHUNKED_INFERENCE)
Expand All @@ -508,6 +529,7 @@ private long addFieldInferenceRequests(BulkItemRequest item, int itemIndex, Map<
}
continue;
}

var slot = ensureResponseAccumulatorSlot(itemIndex);
final List<String> values;
try {
@@ -525,7 +547,10 @@ private long addFieldInferenceRequests(BulkItemRequest item, int itemIndex, Map<
List<FieldInferenceRequest> requests = requestsMap.computeIfAbsent(inferenceId, k -> new ArrayList<>());
int offsetAdjustment = 0;
for (String v : values) {
inputLength += v.length();
if (incrementIndexingPressure(indexRequest, itemIndex) == false) {
return inputLength;
}

if (v.isBlank()) {
slot.addOrUpdateResponse(
new FieldInferenceResponse(field, sourceField, v, order++, 0, null, EMPTY_CHUNKED_INFERENCE)
Expand All @@ -534,6 +559,7 @@ private long addFieldInferenceRequests(BulkItemRequest item, int itemIndex, Map<
requests.add(
new FieldInferenceRequest(itemIndex, field, sourceField, v, order++, offsetAdjustment, chunkingSettings)
);
inputLength += v.length();
}

// When using the inference metadata fields format, all the input values are concatenated so that the
@@ -543,9 +569,54 @@ private long addFieldInferenceRequests(BulkItemRequest item, int itemIndex, Map<
}
}
}

return inputLength;
}

private static class IndexRequestWithIndexingPressure {
private final IndexRequest indexRequest;
private boolean indexingPressureIncremented;

private IndexRequestWithIndexingPressure(IndexRequest indexRequest) {
this.indexRequest = indexRequest;
this.indexingPressureIncremented = false;
}

private IndexRequest getIndexRequest() {
return indexRequest;
}

private boolean isIndexingPressureIncremented() {
return indexingPressureIncremented;
}

private void setIndexingPressureIncremented() {
this.indexingPressureIncremented = true;
}
}

private boolean incrementIndexingPressure(IndexRequestWithIndexingPressure indexRequest, int itemIndex) {
boolean success = true;
if (indexRequest.isIndexingPressureIncremented() == false) {
try {
// Track operation count as one operation per document source update
coordinatingIndexingPressure.increment(1, indexRequest.getIndexRequest().source().ramBytesUsed());
indexRequest.setIndexingPressureIncremented();
} catch (EsRejectedExecutionException e) {
addInferenceResponseFailure(
itemIndex,
new InferenceException(
"Insufficient memory available to update source on document [" + indexRequest.getIndexRequest().id() + "]",
e
)
);
success = false;
}
}

return success;
}

private FieldInferenceResponseAccumulator ensureResponseAccumulatorSlot(int id) {
FieldInferenceResponseAccumulator acc = inferenceResults.get(id);
if (acc == null) {
@@ -622,6 +693,7 @@ private void applyInferenceResponses(BulkItemRequest item, FieldInferenceRespons
inferenceFieldsMap.put(fieldName, result);
}

BytesReference originalSource = indexRequest.source();
if (useLegacyFormat) {
var newDocMap = indexRequest.sourceAsMap();
for (var entry : inferenceFieldsMap.entrySet()) {
Expand All @@ -634,6 +706,23 @@ private void applyInferenceResponses(BulkItemRequest item, FieldInferenceRespons
indexRequest.source(builder);
}
}
long modifiedSourceSize = indexRequest.source().ramBytesUsed();

// Add the indexing pressure from the source modifications.
// Don't increment operation count because we count one source update as one operation, and we already accounted for those
// in addFieldInferenceRequests.
try {
coordinatingIndexingPressure.increment(0, modifiedSourceSize - originalSource.ramBytesUsed());
} catch (EsRejectedExecutionException e) {
indexRequest.source(originalSource, indexRequest.getContentType());
item.abort(
item.index(),
new InferenceException(
"Insufficient memory available to insert inference results into document [" + indexRequest.id() + "]",
e
)
);
}
}
}

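
Taken together, the filter accounts for memory in two phases: incrementIndexingPressure reserves one operation plus the original source size the first time a document produces an inference request, and the applyInferenceResponses block above adds only the delta once the inference results have been embedded into the source. A worked sketch of the arithmetic (the byte sizes are assumptions for illustration, not values from the PR):

```java
// Illustrative arithmetic only; sizes are assumed for the example.
long originalSourceBytes = 2_048;   // source().ramBytesUsed() before inference
long modifiedSourceBytes = 10_240;  // source().ramBytesUsed() with inference results embedded

// Phase 1 (addFieldInferenceRequests): one operation, original source bytes.
coordinatingIndexingPressure.increment(1, originalSourceBytes);  // tracked: 2,048 bytes

// Phase 2 (applyInferenceResponses): no new operation, only the growth.
coordinatingIndexingPressure.increment(0, modifiedSourceBytes - originalSourceBytes);  // +8,192 -> 10,240 bytes

// Either increment may throw EsRejectedExecutionException; the filter then fails
// the item (or restores the original source and aborts it) instead of indexing.
```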