From 062793694f3a24f00b8826ccddca2bab8553b0fa Mon Sep 17 00:00:00 2001
From: Mike Pellegrini
Date: Mon, 14 Apr 2025 15:55:37 -0400
Subject: [PATCH 1/2] Semantic Text Chunking Indexing Pressure (#125517)

We have observed many OOMs due to the memory required to inject chunked
inference results for semantic_text fields. This PR uses coordinating
indexing pressure to account for this memory usage. When indexing
pressure memory usage exceeds the threshold set by
indexing_pressure.memory.limit, chunked inference result injection will
be suspended to prevent OOMs.

(cherry picked from commit 85713f78e00682e6abe3329440682ba2abef0d3f)

# Conflicts:
#	server/src/main/java/org/elasticsearch/node/NodeConstruction.java
#	server/src/main/java/org/elasticsearch/node/PluginServiceInstances.java
#	x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java
#	x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilterTests.java
---
 docs/changelog/125517.yaml                    |   5 +
 .../elasticsearch/index/IndexingPressure.java |   8 +-
 .../elasticsearch/node/NodeConstruction.java  |   6 +-
 .../node/PluginServiceInstances.java          |   4 +-
 .../org/elasticsearch/plugins/Plugin.java     |   6 +
 .../xpack/inference/InferencePlugin.java      |   3 +-
 .../ShardBulkInferenceActionFilter.java       | 111 +++-
 .../ShardBulkInferenceActionFilterTests.java  | 491 +++++++++++++++++-
 .../mapper/SemanticTextFieldTests.java        |  11 +
 9 files changed, 618 insertions(+), 27 deletions(-)
 create mode 100644 docs/changelog/125517.yaml

diff --git a/docs/changelog/125517.yaml b/docs/changelog/125517.yaml
new file mode 100644
index 0000000000000..993a32960c876
--- /dev/null
+++ b/docs/changelog/125517.yaml
@@ -0,0 +1,5 @@
+pr: 125517
+summary: Semantic Text Chunking Indexing Pressure
+area: Machine Learning
+type: enhancement
+issues: []
diff --git a/server/src/main/java/org/elasticsearch/index/IndexingPressure.java b/server/src/main/java/org/elasticsearch/index/IndexingPressure.java
index bc64df9e7bc5d..93d45b1092d22 100644
--- a/server/src/main/java/org/elasticsearch/index/IndexingPressure.java
+++ b/server/src/main/java/org/elasticsearch/index/IndexingPressure.java
@@ -146,11 +146,15 @@ public Incremental startIncrementalCoordinating(int operations, long bytes, bool
     }
 
     public Coordinating markCoordinatingOperationStarted(int operations, long bytes, boolean forceExecution) {
-        Coordinating coordinating = new Coordinating(forceExecution);
+        Coordinating coordinating = createCoordinatingOperation(forceExecution);
         coordinating.increment(operations, bytes);
         return coordinating;
     }
 
+    public Coordinating createCoordinatingOperation(boolean forceExecution) {
+        return new Coordinating(forceExecution);
+    }
+
     public class Incremental implements Releasable {
 
         private final AtomicBoolean closed = new AtomicBoolean();
@@ -243,7 +247,7 @@ public Coordinating(boolean forceExecution) {
             this.forceExecution = forceExecution;
         }
 
-        private void increment(int operations, long bytes) {
+        public void increment(int operations, long bytes) {
             assert closed.get() == false;
             long combinedBytes = currentCombinedCoordinatingAndPrimaryBytes.addAndGet(bytes);
             long replicaWriteBytes = currentReplicaBytes.get();
diff --git a/server/src/main/java/org/elasticsearch/node/NodeConstruction.java b/server/src/main/java/org/elasticsearch/node/NodeConstruction.java
index f2b4d4759c72f..349b50820705c 100644
--- a/server/src/main/java/org/elasticsearch/node/NodeConstruction.java
+++ b/server/src/main/java/org/elasticsearch/node/NodeConstruction.java
@@ -936,6 +936,8 @@ public Map queryFields() {
             metadataCreateIndexService
         );
 
+        final IndexingPressure indexingLimits = new IndexingPressure(settings);
+
         PluginServiceInstances pluginServices = new PluginServiceInstances(
             client,
             clusterService,
@@ -957,7 +959,8 @@ public Map queryFields() {
             dataStreamGlobalRetentionSettings,
             documentParsingProvider,
             taskManager,
-            slowLogFieldProvider
+            slowLogFieldProvider,
+            indexingLimits
         );
 
         Collection pluginComponents = pluginsService.flatMap(plugin -> {
@@ -990,7 +993,6 @@ public Map queryFields() {
             .map(TerminationHandlerProvider::handler);
         terminationHandler = getSinglePlugin(terminationHandlers, TerminationHandler.class).orElse(null);
 
-        final IndexingPressure indexingLimits = new IndexingPressure(settings);
        final IncrementalBulkService incrementalBulkService = new IncrementalBulkService(client, indexingLimits);
 
         ActionModule actionModule = new ActionModule(
diff --git a/server/src/main/java/org/elasticsearch/node/PluginServiceInstances.java b/server/src/main/java/org/elasticsearch/node/PluginServiceInstances.java
index 2dd52fa863d2f..768707c49439e 100644
--- a/server/src/main/java/org/elasticsearch/node/PluginServiceInstances.java
+++ b/server/src/main/java/org/elasticsearch/node/PluginServiceInstances.java
@@ -19,6 +19,7 @@
 import org.elasticsearch.env.Environment;
 import org.elasticsearch.env.NodeEnvironment;
 import org.elasticsearch.features.FeatureService;
+import org.elasticsearch.index.IndexingPressure;
 import org.elasticsearch.index.SlowLogFieldProvider;
 import org.elasticsearch.indices.IndicesService;
 import org.elasticsearch.indices.SystemIndices;
@@ -53,5 +54,6 @@ public record PluginServiceInstances(
     DataStreamGlobalRetentionSettings dataStreamGlobalRetentionSettings,
     DocumentParsingProvider documentParsingProvider,
     TaskManager taskManager,
-    SlowLogFieldProvider slowLogFieldProvider
+    SlowLogFieldProvider slowLogFieldProvider,
+    IndexingPressure indexingPressure
 ) implements Plugin.PluginServices {}
diff --git a/server/src/main/java/org/elasticsearch/plugins/Plugin.java b/server/src/main/java/org/elasticsearch/plugins/Plugin.java
index 28aa5f51297a6..a43f80de098fc 100644
--- a/server/src/main/java/org/elasticsearch/plugins/Plugin.java
+++ b/server/src/main/java/org/elasticsearch/plugins/Plugin.java
@@ -27,6 +27,7 @@
 import org.elasticsearch.features.FeatureService;
 import org.elasticsearch.index.IndexModule;
 import org.elasticsearch.index.IndexSettingProvider;
+import org.elasticsearch.index.IndexingPressure;
 import org.elasticsearch.index.SlowLogFieldProvider;
 import org.elasticsearch.indices.IndicesService;
 import org.elasticsearch.indices.SystemIndices;
@@ -179,6 +180,11 @@ public interface PluginServices {
          * Provider for additional SlowLog fields
          */
        SlowLogFieldProvider slowLogFieldProvider();
+
+        /**
+         * Provider for indexing pressure
+         */
+        IndexingPressure indexingPressure();
     }
 
     /**
diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java
index f7d3f791d79d1..a31f04c411164 100644
--- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java
+++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java
@@ -328,7 +328,8 @@ public Collection createComponents(PluginServices services) {
             services.clusterService(),
             serviceRegistry,
             modelRegistry.get(),
-
getLicenseState() + getLicenseState(), + services.indexingPressure() ); shardBulkInferenceActionFilter.set(actionFilter); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilter.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilter.java index ed5342b2ccbf1..e4ace30a9776d 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilter.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilter.java @@ -28,11 +28,13 @@ import org.elasticsearch.common.settings.Setting; import org.elasticsearch.common.unit.ByteSizeValue; import org.elasticsearch.common.util.concurrent.AtomicArray; +import org.elasticsearch.common.util.concurrent.EsRejectedExecutionException; import org.elasticsearch.common.xcontent.XContentHelper; import org.elasticsearch.common.xcontent.support.XContentMapValues; import org.elasticsearch.core.Nullable; import org.elasticsearch.core.Releasable; import org.elasticsearch.core.TimeValue; +import org.elasticsearch.index.IndexingPressure; import org.elasticsearch.index.mapper.InferenceMetadataFieldsMapper; import org.elasticsearch.inference.ChunkInferenceInput; import org.elasticsearch.inference.ChunkedInference; @@ -108,18 +110,21 @@ public class ShardBulkInferenceActionFilter implements MappedActionFilter { private final InferenceServiceRegistry inferenceServiceRegistry; private final ModelRegistry modelRegistry; private final XPackLicenseState licenseState; + private final IndexingPressure indexingPressure; private volatile long batchSizeInBytes; public ShardBulkInferenceActionFilter( ClusterService clusterService, InferenceServiceRegistry inferenceServiceRegistry, ModelRegistry modelRegistry, - XPackLicenseState licenseState + XPackLicenseState licenseState, + IndexingPressure indexingPressure ) { this.clusterService = clusterService; this.inferenceServiceRegistry = inferenceServiceRegistry; this.modelRegistry = modelRegistry; this.licenseState = licenseState; + this.indexingPressure = indexingPressure; this.batchSizeInBytes = INDICES_INFERENCE_BATCH_SIZE.get(clusterService.getSettings()).getBytes(); clusterService.getClusterSettings().addSettingsUpdateConsumer(INDICES_INFERENCE_BATCH_SIZE, this::setBatchSize); } @@ -145,8 +150,15 @@ public void app BulkShardRequest bulkShardRequest = (BulkShardRequest) request; var fieldInferenceMetadata = bulkShardRequest.consumeInferenceFieldMap(); if (fieldInferenceMetadata != null && fieldInferenceMetadata.isEmpty() == false) { - Runnable onInferenceCompletion = () -> chain.proceed(task, action, request, listener); - processBulkShardRequest(fieldInferenceMetadata, bulkShardRequest, onInferenceCompletion); + // Maintain coordinating indexing pressure from inference until the indexing operations are complete + IndexingPressure.Coordinating coordinatingIndexingPressure = indexingPressure.createCoordinatingOperation(false); + Runnable onInferenceCompletion = () -> chain.proceed( + task, + action, + request, + ActionListener.releaseAfter(listener, coordinatingIndexingPressure) + ); + processBulkShardRequest(fieldInferenceMetadata, bulkShardRequest, onInferenceCompletion, coordinatingIndexingPressure); return; } } @@ -156,11 +168,13 @@ public void app private void processBulkShardRequest( Map fieldInferenceMap, BulkShardRequest bulkShardRequest, - Runnable onCompletion + 
Runnable onCompletion, + IndexingPressure.Coordinating coordinatingIndexingPressure ) { var index = clusterService.state().getMetadata().index(bulkShardRequest.index()); boolean useLegacyFormat = InferenceMetadataFieldsMapper.isEnabled(index.getSettings()) == false; - new AsyncBulkShardInferenceAction(useLegacyFormat, fieldInferenceMap, bulkShardRequest, onCompletion).run(); + new AsyncBulkShardInferenceAction(useLegacyFormat, fieldInferenceMap, bulkShardRequest, onCompletion, coordinatingIndexingPressure) + .run(); } private record InferenceProvider(InferenceService service, Model model) {} @@ -230,18 +244,21 @@ private class AsyncBulkShardInferenceAction implements Runnable { private final BulkShardRequest bulkShardRequest; private final Runnable onCompletion; private final AtomicArray inferenceResults; + private final IndexingPressure.Coordinating coordinatingIndexingPressure; private AsyncBulkShardInferenceAction( boolean useLegacyFormat, Map fieldInferenceMap, BulkShardRequest bulkShardRequest, - Runnable onCompletion + Runnable onCompletion, + IndexingPressure.Coordinating coordinatingIndexingPressure ) { this.useLegacyFormat = useLegacyFormat; this.fieldInferenceMap = fieldInferenceMap; this.bulkShardRequest = bulkShardRequest; this.inferenceResults = new AtomicArray<>(bulkShardRequest.items().length); this.onCompletion = onCompletion; + this.coordinatingIndexingPressure = coordinatingIndexingPressure; } @Override @@ -429,9 +446,9 @@ public void onFailure(Exception exc) { */ private long addFieldInferenceRequests(BulkItemRequest item, int itemIndex, Map> requestsMap) { boolean isUpdateRequest = false; - final IndexRequest indexRequest; + final IndexRequestWithIndexingPressure indexRequest; if (item.request() instanceof IndexRequest ir) { - indexRequest = ir; + indexRequest = new IndexRequestWithIndexingPressure(ir); } else if (item.request() instanceof UpdateRequest updateRequest) { isUpdateRequest = true; if (updateRequest.script() != null) { @@ -445,13 +462,13 @@ private long addFieldInferenceRequests(BulkItemRequest item, int itemIndex, Map< ); return 0; } - indexRequest = updateRequest.doc(); + indexRequest = new IndexRequestWithIndexingPressure(updateRequest.doc()); } else { // ignore delete request return 0; } - final Map docMap = indexRequest.sourceAsMap(); + final Map docMap = indexRequest.getIndexRequest().sourceAsMap(); long inputLength = 0; for (var entry : fieldInferenceMap.values()) { String field = entry.getName(); @@ -487,6 +504,10 @@ private long addFieldInferenceRequests(BulkItemRequest item, int itemIndex, Map< * This ensures that the field is treated as intentionally cleared, * preventing any unintended carryover of prior inference results. 
*/ + if (incrementIndexingPressure(indexRequest, itemIndex) == false) { + return inputLength; + } + var slot = ensureResponseAccumulatorSlot(itemIndex); slot.addOrUpdateResponse( new FieldInferenceResponse(field, sourceField, null, order++, 0, null, EMPTY_CHUNKED_INFERENCE) @@ -508,6 +529,7 @@ private long addFieldInferenceRequests(BulkItemRequest item, int itemIndex, Map< } continue; } + var slot = ensureResponseAccumulatorSlot(itemIndex); final List values; try { @@ -525,7 +547,10 @@ private long addFieldInferenceRequests(BulkItemRequest item, int itemIndex, Map< List requests = requestsMap.computeIfAbsent(inferenceId, k -> new ArrayList<>()); int offsetAdjustment = 0; for (String v : values) { - inputLength += v.length(); + if (incrementIndexingPressure(indexRequest, itemIndex) == false) { + return inputLength; + } + if (v.isBlank()) { slot.addOrUpdateResponse( new FieldInferenceResponse(field, sourceField, v, order++, 0, null, EMPTY_CHUNKED_INFERENCE) @@ -534,6 +559,7 @@ private long addFieldInferenceRequests(BulkItemRequest item, int itemIndex, Map< requests.add( new FieldInferenceRequest(itemIndex, field, sourceField, v, order++, offsetAdjustment, chunkingSettings) ); + inputLength += v.length(); } // When using the inference metadata fields format, all the input values are concatenated so that the @@ -543,9 +569,54 @@ private long addFieldInferenceRequests(BulkItemRequest item, int itemIndex, Map< } } } + return inputLength; } + private static class IndexRequestWithIndexingPressure { + private final IndexRequest indexRequest; + private boolean indexingPressureIncremented; + + private IndexRequestWithIndexingPressure(IndexRequest indexRequest) { + this.indexRequest = indexRequest; + this.indexingPressureIncremented = false; + } + + private IndexRequest getIndexRequest() { + return indexRequest; + } + + private boolean isIndexingPressureIncremented() { + return indexingPressureIncremented; + } + + private void setIndexingPressureIncremented() { + this.indexingPressureIncremented = true; + } + } + + private boolean incrementIndexingPressure(IndexRequestWithIndexingPressure indexRequest, int itemIndex) { + boolean success = true; + if (indexRequest.isIndexingPressureIncremented() == false) { + try { + // Track operation count as one operation per document source update + coordinatingIndexingPressure.increment(1, indexRequest.getIndexRequest().source().ramBytesUsed()); + indexRequest.setIndexingPressureIncremented(); + } catch (EsRejectedExecutionException e) { + addInferenceResponseFailure( + itemIndex, + new InferenceException( + "Insufficient memory available to update source on document [" + indexRequest.getIndexRequest().id() + "]", + e + ) + ); + success = false; + } + } + + return success; + } + private FieldInferenceResponseAccumulator ensureResponseAccumulatorSlot(int id) { FieldInferenceResponseAccumulator acc = inferenceResults.get(id); if (acc == null) { @@ -622,6 +693,7 @@ private void applyInferenceResponses(BulkItemRequest item, FieldInferenceRespons inferenceFieldsMap.put(fieldName, result); } + BytesReference originalSource = indexRequest.source(); if (useLegacyFormat) { var newDocMap = indexRequest.sourceAsMap(); for (var entry : inferenceFieldsMap.entrySet()) { @@ -634,6 +706,23 @@ private void applyInferenceResponses(BulkItemRequest item, FieldInferenceRespons indexRequest.source(builder); } } + long modifiedSourceSize = indexRequest.source().ramBytesUsed(); + + // Add the indexing pressure from the source modifications. 
+ // Don't increment operation count because we count one source update as one operation, and we already accounted for those + // in addFieldInferenceRequests. + try { + coordinatingIndexingPressure.increment(0, modifiedSourceSize - originalSource.ramBytesUsed()); + } catch (EsRejectedExecutionException e) { + indexRequest.source(originalSource, indexRequest.getContentType()); + item.abort( + item.index(), + new InferenceException( + "Insufficient memory available to insert inference results into document [" + indexRequest.id() + "]", + e + ) + ); + } } } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilterTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilterTests.java index 99a79b81cabac..669900cc5d710 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilterTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilterTests.java @@ -15,6 +15,7 @@ import org.elasticsearch.action.bulk.BulkItemRequest; import org.elasticsearch.action.bulk.BulkItemResponse; import org.elasticsearch.action.bulk.BulkShardRequest; +import org.elasticsearch.action.bulk.BulkShardResponse; import org.elasticsearch.action.bulk.TransportShardBulkAction; import org.elasticsearch.action.index.IndexRequest; import org.elasticsearch.action.support.ActionFilterChain; @@ -26,19 +27,24 @@ import org.elasticsearch.cluster.metadata.InferenceFieldMetadata; import org.elasticsearch.cluster.metadata.Metadata; import org.elasticsearch.cluster.service.ClusterService; +import org.elasticsearch.common.CheckedBiFunction; import org.elasticsearch.common.Strings; +import org.elasticsearch.common.bytes.BytesReference; import org.elasticsearch.common.settings.ClusterSettings; import org.elasticsearch.common.settings.Settings; import org.elasticsearch.common.unit.ByteSizeValue; +import org.elasticsearch.common.util.concurrent.EsRejectedExecutionException; import org.elasticsearch.common.xcontent.XContentHelper; import org.elasticsearch.common.xcontent.support.XContentMapValues; import org.elasticsearch.index.IndexVersion; +import org.elasticsearch.index.IndexingPressure; import org.elasticsearch.index.mapper.InferenceMetadataFieldsMapper; import org.elasticsearch.index.shard.ShardId; import org.elasticsearch.inference.ChunkInferenceInput; import org.elasticsearch.inference.ChunkedInference; import org.elasticsearch.inference.InferenceService; import org.elasticsearch.inference.InferenceServiceRegistry; +import org.elasticsearch.inference.MinimalServiceSettings; import org.elasticsearch.inference.Model; import org.elasticsearch.inference.TaskType; import org.elasticsearch.inference.UnparsedModel; @@ -48,6 +54,8 @@ import org.elasticsearch.test.ESTestCase; import org.elasticsearch.threadpool.TestThreadPool; import org.elasticsearch.threadpool.ThreadPool; +import org.elasticsearch.xcontent.XContentBuilder; +import org.elasticsearch.xcontent.XContentFactory; import org.elasticsearch.xcontent.XContentType; import org.elasticsearch.xcontent.json.JsonXContent; import org.elasticsearch.xpack.core.XPackField; @@ -72,13 +80,16 @@ import java.util.concurrent.CountDownLatch; import java.util.concurrent.TimeUnit; +import static org.elasticsearch.index.IndexingPressure.MAX_COORDINATING_BYTES; import static 
org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertToXContentEquivalent; import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.awaitLatch; +import static org.elasticsearch.xcontent.ToXContent.EMPTY_PARAMS; import static org.elasticsearch.xpack.inference.action.filter.ShardBulkInferenceActionFilter.INDICES_INFERENCE_BATCH_SIZE; import static org.elasticsearch.xpack.inference.action.filter.ShardBulkInferenceActionFilter.getIndexRequestOrNull; import static org.elasticsearch.xpack.inference.mapper.SemanticTextField.getChunksFieldName; import static org.elasticsearch.xpack.inference.mapper.SemanticTextField.getOriginalTextFieldName; -import static org.elasticsearch.xpack.inference.mapper.SemanticTextFieldTests.randomChunkedInferenceEmbeddingSparse; +import static org.elasticsearch.xpack.inference.mapper.SemanticTextFieldMapperTests.addSemanticTextInferenceResults; +import static org.elasticsearch.xpack.inference.mapper.SemanticTextFieldTests.randomChunkedInferenceEmbedding; import static org.elasticsearch.xpack.inference.mapper.SemanticTextFieldTests.randomSemanticText; import static org.elasticsearch.xpack.inference.mapper.SemanticTextFieldTests.randomSemanticTextInput; import static org.elasticsearch.xpack.inference.mapper.SemanticTextFieldTests.semanticTextFieldFromChunkedInferenceResults; @@ -87,13 +98,24 @@ import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.instanceOf; import static org.hamcrest.Matchers.is; +import static org.hamcrest.Matchers.notNullValue; +import static org.mockito.ArgumentMatchers.anyInt; +import static org.mockito.ArgumentMatchers.anyLong; +import static org.mockito.ArgumentMatchers.anyString; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.ArgumentMatchers.longThat; import static org.mockito.Mockito.any; import static org.mockito.Mockito.doAnswer; import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.never; +import static org.mockito.Mockito.spy; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; public class ShardBulkInferenceActionFilterTests extends ESTestCase { private static final Object EXPLICIT_NULL = new Object(); + private static final IndexingPressure NOOP_INDEXING_PRESSURE = new NoopIndexingPressure(); private final boolean useLegacyFormat; private ThreadPool threadPool; @@ -119,7 +141,7 @@ public void tearDownThreadPool() throws Exception { @SuppressWarnings({ "unchecked", "rawtypes" }) public void testFilterNoop() throws Exception { - ShardBulkInferenceActionFilter filter = createFilter(threadPool, Map.of(), useLegacyFormat, true); + ShardBulkInferenceActionFilter filter = createFilter(threadPool, Map.of(), NOOP_INDEXING_PRESSURE, useLegacyFormat, true); CountDownLatch chainExecuted = new CountDownLatch(1); ActionFilterChain actionFilterChain = (task, action, request, listener) -> { try { @@ -145,7 +167,7 @@ public void testFilterNoop() throws Exception { @SuppressWarnings({ "unchecked", "rawtypes" }) public void testLicenseInvalidForInference() throws InterruptedException { StaticModel model = StaticModel.createRandomInstance(); - ShardBulkInferenceActionFilter filter = createFilter(threadPool, Map.of(), useLegacyFormat, false); + ShardBulkInferenceActionFilter filter = createFilter(threadPool, Map.of(), NOOP_INDEXING_PRESSURE, useLegacyFormat, false); CountDownLatch chainExecuted = new CountDownLatch(1); ActionFilterChain actionFilterChain = (task, action, request, listener) 
-> { try { @@ -186,6 +208,7 @@ public void testInferenceNotFound() throws Exception { ShardBulkInferenceActionFilter filter = createFilter( threadPool, Map.of(model.getInferenceEntityId(), model), + NOOP_INDEXING_PRESSURE, useLegacyFormat, true ); @@ -227,16 +250,17 @@ public void testInferenceNotFound() throws Exception { @SuppressWarnings({ "unchecked", "rawtypes" }) public void testItemFailures() throws Exception { - StaticModel model = StaticModel.createRandomInstance(); + StaticModel model = StaticModel.createRandomInstance(TaskType.SPARSE_EMBEDDING); ShardBulkInferenceActionFilter filter = createFilter( threadPool, Map.of(model.getInferenceEntityId(), model), + NOOP_INDEXING_PRESSURE, useLegacyFormat, true ); model.putResult("I am a failure", new ChunkedInferenceError(new IllegalArgumentException("boom"))); - model.putResult("I am a success", randomChunkedInferenceEmbeddingSparse(List.of("I am a success"))); + model.putResult("I am a success", randomChunkedInferenceEmbedding(model, List.of("I am a success"))); CountDownLatch chainExecuted = new CountDownLatch(1); ActionFilterChain actionFilterChain = (task, action, request, listener) -> { try { @@ -295,13 +319,14 @@ public void testItemFailures() throws Exception { @SuppressWarnings({ "unchecked", "rawtypes" }) public void testExplicitNull() throws Exception { - StaticModel model = StaticModel.createRandomInstance(); + StaticModel model = StaticModel.createRandomInstance(TaskType.SPARSE_EMBEDDING); model.putResult("I am a failure", new ChunkedInferenceError(new IllegalArgumentException("boom"))); - model.putResult("I am a success", randomChunkedInferenceEmbeddingSparse(List.of("I am a success"))); + model.putResult("I am a success", randomChunkedInferenceEmbedding(model, List.of("I am a success"))); ShardBulkInferenceActionFilter filter = createFilter( threadPool, Map.of(model.getInferenceEntityId(), model), + NOOP_INDEXING_PRESSURE, useLegacyFormat, true ); @@ -372,6 +397,7 @@ public void testHandleEmptyInput() throws Exception { ShardBulkInferenceActionFilter filter = createFilter( threadPool, Map.of(model.getInferenceEntityId(), model), + NOOP_INDEXING_PRESSURE, useLegacyFormat, true ); @@ -444,7 +470,7 @@ public void testManyRandomDocs() throws Exception { modifiedRequests[id] = res[1]; } - ShardBulkInferenceActionFilter filter = createFilter(threadPool, inferenceModelMap, useLegacyFormat, true); + ShardBulkInferenceActionFilter filter = createFilter(threadPool, inferenceModelMap, NOOP_INDEXING_PRESSURE, useLegacyFormat, true); CountDownLatch chainExecuted = new CountDownLatch(1); ActionFilterChain actionFilterChain = (task, action, request, listener) -> { try { @@ -474,10 +500,397 @@ public void testManyRandomDocs() throws Exception { awaitLatch(chainExecuted, 10, TimeUnit.SECONDS); } + @SuppressWarnings({ "unchecked", "rawtypes" }) + public void testIndexingPressure() throws Exception { + final InstrumentedIndexingPressure indexingPressure = new InstrumentedIndexingPressure(Settings.EMPTY); + final StaticModel sparseModel = StaticModel.createRandomInstance(TaskType.SPARSE_EMBEDDING); + final StaticModel denseModel = StaticModel.createRandomInstance(TaskType.TEXT_EMBEDDING); + final ShardBulkInferenceActionFilter filter = createFilter( + threadPool, + Map.of(sparseModel.getInferenceEntityId(), sparseModel, denseModel.getInferenceEntityId(), denseModel), + indexingPressure, + useLegacyFormat, + true + ); + + XContentBuilder doc0Source = IndexRequest.getXContentBuilder(XContentType.JSON, "sparse_field", "a test value"); + 
XContentBuilder doc1Source = IndexRequest.getXContentBuilder(XContentType.JSON, "dense_field", "another test value"); + XContentBuilder doc2Source = IndexRequest.getXContentBuilder( + XContentType.JSON, + "sparse_field", + "a test value", + "dense_field", + "another test value" + ); + XContentBuilder doc3Source = IndexRequest.getXContentBuilder( + XContentType.JSON, + "dense_field", + List.of("value one", " ", "value two") + ); + XContentBuilder doc4Source = IndexRequest.getXContentBuilder(XContentType.JSON, "sparse_field", " "); + XContentBuilder doc5Source = XContentFactory.contentBuilder(XContentType.JSON); + { + doc5Source.startObject(); + if (useLegacyFormat == false) { + doc5Source.field("sparse_field", "a test value"); + } + addSemanticTextInferenceResults( + useLegacyFormat, + doc5Source, + List.of(randomSemanticText(useLegacyFormat, "sparse_field", sparseModel, null, List.of("a test value"), XContentType.JSON)) + ); + doc5Source.endObject(); + } + XContentBuilder doc0UpdateSource = IndexRequest.getXContentBuilder(XContentType.JSON, "sparse_field", "an updated value"); + XContentBuilder doc1UpdateSource = IndexRequest.getXContentBuilder(XContentType.JSON, "dense_field", null); + + CountDownLatch chainExecuted = new CountDownLatch(1); + ActionFilterChain actionFilterChain = (task, action, request, listener) -> { + try { + BulkShardRequest bulkShardRequest = (BulkShardRequest) request; + assertNull(bulkShardRequest.getInferenceFieldMap()); + assertThat(bulkShardRequest.items().length, equalTo(10)); + + for (BulkItemRequest item : bulkShardRequest.items()) { + assertNull(item.getPrimaryResponse()); + } + + IndexingPressure.Coordinating coordinatingIndexingPressure = indexingPressure.getCoordinating(); + assertThat(coordinatingIndexingPressure, notNullValue()); + verify(coordinatingIndexingPressure).increment(1, bytesUsed(doc0Source)); + verify(coordinatingIndexingPressure).increment(1, bytesUsed(doc1Source)); + verify(coordinatingIndexingPressure).increment(1, bytesUsed(doc2Source)); + verify(coordinatingIndexingPressure).increment(1, bytesUsed(doc3Source)); + verify(coordinatingIndexingPressure).increment(1, bytesUsed(doc4Source)); + verify(coordinatingIndexingPressure).increment(1, bytesUsed(doc0UpdateSource)); + if (useLegacyFormat == false) { + verify(coordinatingIndexingPressure).increment(1, bytesUsed(doc1UpdateSource)); + } + + verify(coordinatingIndexingPressure, times(useLegacyFormat ? 6 : 7)).increment(eq(0), longThat(l -> l > 0)); + + // Verify that the only times that increment is called are the times verified above + verify(coordinatingIndexingPressure, times(useLegacyFormat ? 
12 : 14)).increment(anyInt(), anyLong()); + + // Verify that the coordinating indexing pressure is maintained through downstream action filters + verify(coordinatingIndexingPressure, never()).close(); + + // Call the listener once the request is successfully processed, like is done in the production code path + listener.onResponse(null); + } finally { + chainExecuted.countDown(); + } + }; + ActionListener actionListener = mock(ActionListener.class); + Task task = mock(Task.class); + + Map inferenceFieldMap = Map.of( + "sparse_field", + new InferenceFieldMetadata("sparse_field", sparseModel.getInferenceEntityId(), new String[] { "sparse_field" }, null), + "dense_field", + new InferenceFieldMetadata("dense_field", denseModel.getInferenceEntityId(), new String[] { "dense_field" }, null) + ); + + BulkItemRequest[] items = new BulkItemRequest[10]; + items[0] = new BulkItemRequest(0, new IndexRequest("index").id("doc_0").source(doc0Source)); + items[1] = new BulkItemRequest(1, new IndexRequest("index").id("doc_1").source(doc1Source)); + items[2] = new BulkItemRequest(2, new IndexRequest("index").id("doc_2").source(doc2Source)); + items[3] = new BulkItemRequest(3, new IndexRequest("index").id("doc_3").source(doc3Source)); + items[4] = new BulkItemRequest(4, new IndexRequest("index").id("doc_4").source(doc4Source)); + items[5] = new BulkItemRequest(5, new IndexRequest("index").id("doc_5").source(doc5Source)); + items[6] = new BulkItemRequest(6, new IndexRequest("index").id("doc_6").source("non_inference_field", "yet another test value")); + items[7] = new BulkItemRequest(7, new UpdateRequest().doc(new IndexRequest("index").id("doc_0").source(doc0UpdateSource))); + items[8] = new BulkItemRequest(8, new UpdateRequest().doc(new IndexRequest("index").id("doc_1").source(doc1UpdateSource))); + items[9] = new BulkItemRequest( + 9, + new UpdateRequest().doc(new IndexRequest("index").id("doc_3").source("non_inference_field", "yet another updated value")) + ); + + BulkShardRequest request = new BulkShardRequest(new ShardId("test", "test", 0), WriteRequest.RefreshPolicy.NONE, items); + request.setInferenceFieldMap(inferenceFieldMap); + filter.apply(task, TransportShardBulkAction.ACTION_NAME, request, actionListener, actionFilterChain); + awaitLatch(chainExecuted, 10, TimeUnit.SECONDS); + + IndexingPressure.Coordinating coordinatingIndexingPressure = indexingPressure.getCoordinating(); + assertThat(coordinatingIndexingPressure, notNullValue()); + verify(coordinatingIndexingPressure).close(); + } + + @SuppressWarnings("unchecked") + public void testIndexingPressureTripsOnInferenceRequestGeneration() throws Exception { + final InstrumentedIndexingPressure indexingPressure = new InstrumentedIndexingPressure( + Settings.builder().put(MAX_COORDINATING_BYTES.getKey(), "1b").build() + ); + final StaticModel sparseModel = StaticModel.createRandomInstance(TaskType.SPARSE_EMBEDDING); + final ShardBulkInferenceActionFilter filter = createFilter( + threadPool, + Map.of(sparseModel.getInferenceEntityId(), sparseModel), + indexingPressure, + useLegacyFormat, + true + ); + + XContentBuilder doc1Source = IndexRequest.getXContentBuilder(XContentType.JSON, "sparse_field", "bar"); + + CountDownLatch chainExecuted = new CountDownLatch(1); + ActionFilterChain actionFilterChain = (task, action, request, listener) -> { + try { + assertNull(request.getInferenceFieldMap()); + assertThat(request.items().length, equalTo(3)); + + assertNull(request.items()[0].getPrimaryResponse()); + assertNull(request.items()[2].getPrimaryResponse()); 
+ + BulkItemRequest doc1Request = request.items()[1]; + BulkItemResponse doc1Response = doc1Request.getPrimaryResponse(); + assertNotNull(doc1Response); + assertTrue(doc1Response.isFailed()); + BulkItemResponse.Failure doc1Failure = doc1Response.getFailure(); + assertThat( + doc1Failure.getCause().getMessage(), + containsString("Insufficient memory available to update source on document [doc_1]") + ); + assertThat(doc1Failure.getCause().getCause(), instanceOf(EsRejectedExecutionException.class)); + assertThat(doc1Failure.getStatus(), is(RestStatus.TOO_MANY_REQUESTS)); + + IndexRequest doc1IndexRequest = getIndexRequestOrNull(doc1Request.request()); + assertThat(doc1IndexRequest, notNullValue()); + assertThat(doc1IndexRequest.source(), equalTo(BytesReference.bytes(doc1Source))); + + IndexingPressure.Coordinating coordinatingIndexingPressure = indexingPressure.getCoordinating(); + assertThat(coordinatingIndexingPressure, notNullValue()); + verify(coordinatingIndexingPressure).increment(1, bytesUsed(doc1Source)); + verify(coordinatingIndexingPressure, times(1)).increment(anyInt(), anyLong()); + + // Verify that the coordinating indexing pressure is maintained through downstream action filters + verify(coordinatingIndexingPressure, never()).close(); + + // Call the listener once the request is successfully processed, like is done in the production code path + listener.onResponse(null); + } finally { + chainExecuted.countDown(); + } + }; + ActionListener actionListener = (ActionListener) mock(ActionListener.class); + Task task = mock(Task.class); + + Map inferenceFieldMap = Map.of( + "sparse_field", + new InferenceFieldMetadata("sparse_field", sparseModel.getInferenceEntityId(), new String[] { "sparse_field" }, null) + ); + + BulkItemRequest[] items = new BulkItemRequest[3]; + items[0] = new BulkItemRequest(0, new IndexRequest("index").id("doc_0").source("non_inference_field", "foo")); + items[1] = new BulkItemRequest(1, new IndexRequest("index").id("doc_1").source(doc1Source)); + items[2] = new BulkItemRequest(2, new IndexRequest("index").id("doc_2").source("non_inference_field", "baz")); + + BulkShardRequest request = new BulkShardRequest(new ShardId("test", "test", 0), WriteRequest.RefreshPolicy.NONE, items); + request.setInferenceFieldMap(inferenceFieldMap); + filter.apply(task, TransportShardBulkAction.ACTION_NAME, request, actionListener, actionFilterChain); + awaitLatch(chainExecuted, 10, TimeUnit.SECONDS); + + IndexingPressure.Coordinating coordinatingIndexingPressure = indexingPressure.getCoordinating(); + assertThat(coordinatingIndexingPressure, notNullValue()); + verify(coordinatingIndexingPressure).close(); + } + + @SuppressWarnings("unchecked") + public void testIndexingPressureTripsOnInferenceResponseHandling() throws Exception { + final XContentBuilder doc1Source = IndexRequest.getXContentBuilder(XContentType.JSON, "sparse_field", "bar"); + final InstrumentedIndexingPressure indexingPressure = new InstrumentedIndexingPressure( + Settings.builder().put(MAX_COORDINATING_BYTES.getKey(), (bytesUsed(doc1Source) + 1) + "b").build() + ); + + final StaticModel sparseModel = StaticModel.createRandomInstance(TaskType.SPARSE_EMBEDDING); + sparseModel.putResult("bar", randomChunkedInferenceEmbedding(sparseModel, List.of("bar"))); + + final ShardBulkInferenceActionFilter filter = createFilter( + threadPool, + Map.of(sparseModel.getInferenceEntityId(), sparseModel), + indexingPressure, + useLegacyFormat, + true + ); + + CountDownLatch chainExecuted = new CountDownLatch(1); + ActionFilterChain 
actionFilterChain = (task, action, request, listener) -> { + try { + assertNull(request.getInferenceFieldMap()); + assertThat(request.items().length, equalTo(3)); + + assertNull(request.items()[0].getPrimaryResponse()); + assertNull(request.items()[2].getPrimaryResponse()); + + BulkItemRequest doc1Request = request.items()[1]; + BulkItemResponse doc1Response = doc1Request.getPrimaryResponse(); + assertNotNull(doc1Response); + assertTrue(doc1Response.isFailed()); + BulkItemResponse.Failure doc1Failure = doc1Response.getFailure(); + assertThat( + doc1Failure.getCause().getMessage(), + containsString("Insufficient memory available to insert inference results into document [doc_1]") + ); + assertThat(doc1Failure.getCause().getCause(), instanceOf(EsRejectedExecutionException.class)); + assertThat(doc1Failure.getStatus(), is(RestStatus.TOO_MANY_REQUESTS)); + + IndexRequest doc1IndexRequest = getIndexRequestOrNull(doc1Request.request()); + assertThat(doc1IndexRequest, notNullValue()); + assertThat(doc1IndexRequest.source(), equalTo(BytesReference.bytes(doc1Source))); + + IndexingPressure.Coordinating coordinatingIndexingPressure = indexingPressure.getCoordinating(); + assertThat(coordinatingIndexingPressure, notNullValue()); + verify(coordinatingIndexingPressure).increment(1, bytesUsed(doc1Source)); + verify(coordinatingIndexingPressure).increment(eq(0), longThat(l -> l > 0)); + verify(coordinatingIndexingPressure, times(2)).increment(anyInt(), anyLong()); + + // Verify that the coordinating indexing pressure is maintained through downstream action filters + verify(coordinatingIndexingPressure, never()).close(); + + // Call the listener once the request is successfully processed, like is done in the production code path + listener.onResponse(null); + } finally { + chainExecuted.countDown(); + } + }; + ActionListener actionListener = (ActionListener) mock(ActionListener.class); + Task task = mock(Task.class); + + Map inferenceFieldMap = Map.of( + "sparse_field", + new InferenceFieldMetadata("sparse_field", sparseModel.getInferenceEntityId(), new String[] { "sparse_field" }, null) + ); + + BulkItemRequest[] items = new BulkItemRequest[3]; + items[0] = new BulkItemRequest(0, new IndexRequest("index").id("doc_0").source("non_inference_field", "foo")); + items[1] = new BulkItemRequest(1, new IndexRequest("index").id("doc_1").source(doc1Source)); + items[2] = new BulkItemRequest(2, new IndexRequest("index").id("doc_2").source("non_inference_field", "baz")); + + BulkShardRequest request = new BulkShardRequest(new ShardId("test", "test", 0), WriteRequest.RefreshPolicy.NONE, items); + request.setInferenceFieldMap(inferenceFieldMap); + filter.apply(task, TransportShardBulkAction.ACTION_NAME, request, actionListener, actionFilterChain); + awaitLatch(chainExecuted, 10, TimeUnit.SECONDS); + + IndexingPressure.Coordinating coordinatingIndexingPressure = indexingPressure.getCoordinating(); + assertThat(coordinatingIndexingPressure, notNullValue()); + verify(coordinatingIndexingPressure).close(); + } + + @SuppressWarnings("unchecked") + public void testIndexingPressurePartialFailure() throws Exception { + // Use different length strings so that doc 1 and doc 2 sources are different sizes + final XContentBuilder doc1Source = IndexRequest.getXContentBuilder(XContentType.JSON, "sparse_field", "bar"); + final XContentBuilder doc2Source = IndexRequest.getXContentBuilder(XContentType.JSON, "sparse_field", "bazzz"); + + final StaticModel sparseModel = StaticModel.createRandomInstance(TaskType.SPARSE_EMBEDDING); + final 
ChunkedInferenceEmbedding barEmbedding = randomChunkedInferenceEmbedding(sparseModel, List.of("bar")); + final ChunkedInferenceEmbedding bazzzEmbedding = randomChunkedInferenceEmbedding(sparseModel, List.of("bazzz")); + sparseModel.putResult("bar", barEmbedding); + sparseModel.putResult("bazzz", bazzzEmbedding); + + CheckedBiFunction, ChunkedInference, Long, IOException> estimateInferenceResultsBytes = (inputs, inference) -> { + SemanticTextField semanticTextField = semanticTextFieldFromChunkedInferenceResults( + useLegacyFormat, + "sparse_field", + sparseModel, + null, + inputs, + inference, + XContentType.JSON + ); + XContentBuilder builder = XContentFactory.jsonBuilder(); + semanticTextField.toXContent(builder, EMPTY_PARAMS); + return bytesUsed(builder); + }; + + final InstrumentedIndexingPressure indexingPressure = new InstrumentedIndexingPressure( + Settings.builder() + .put( + MAX_COORDINATING_BYTES.getKey(), + (bytesUsed(doc1Source) + bytesUsed(doc2Source) + estimateInferenceResultsBytes.apply(List.of("bar"), barEmbedding) + + (estimateInferenceResultsBytes.apply(List.of("bazzz"), bazzzEmbedding) / 2)) + "b" + ) + .build() + ); + + final ShardBulkInferenceActionFilter filter = createFilter( + threadPool, + Map.of(sparseModel.getInferenceEntityId(), sparseModel), + indexingPressure, + useLegacyFormat, + true + ); + + CountDownLatch chainExecuted = new CountDownLatch(1); + ActionFilterChain actionFilterChain = (task, action, request, listener) -> { + try { + assertNull(request.getInferenceFieldMap()); + assertThat(request.items().length, equalTo(4)); + + assertNull(request.items()[0].getPrimaryResponse()); + assertNull(request.items()[1].getPrimaryResponse()); + assertNull(request.items()[3].getPrimaryResponse()); + + BulkItemRequest doc2Request = request.items()[2]; + BulkItemResponse doc2Response = doc2Request.getPrimaryResponse(); + assertNotNull(doc2Response); + assertTrue(doc2Response.isFailed()); + BulkItemResponse.Failure doc2Failure = doc2Response.getFailure(); + assertThat( + doc2Failure.getCause().getMessage(), + containsString("Insufficient memory available to insert inference results into document [doc_2]") + ); + assertThat(doc2Failure.getCause().getCause(), instanceOf(EsRejectedExecutionException.class)); + assertThat(doc2Failure.getStatus(), is(RestStatus.TOO_MANY_REQUESTS)); + + IndexRequest doc2IndexRequest = getIndexRequestOrNull(doc2Request.request()); + assertThat(doc2IndexRequest, notNullValue()); + assertThat(doc2IndexRequest.source(), equalTo(BytesReference.bytes(doc2Source))); + + IndexingPressure.Coordinating coordinatingIndexingPressure = indexingPressure.getCoordinating(); + assertThat(coordinatingIndexingPressure, notNullValue()); + verify(coordinatingIndexingPressure).increment(1, bytesUsed(doc1Source)); + verify(coordinatingIndexingPressure).increment(1, bytesUsed(doc2Source)); + verify(coordinatingIndexingPressure, times(2)).increment(eq(0), longThat(l -> l > 0)); + verify(coordinatingIndexingPressure, times(4)).increment(anyInt(), anyLong()); + + // Verify that the coordinating indexing pressure is maintained through downstream action filters + verify(coordinatingIndexingPressure, never()).close(); + + // Call the listener once the request is successfully processed, like is done in the production code path + listener.onResponse(null); + } finally { + chainExecuted.countDown(); + } + }; + ActionListener actionListener = (ActionListener) mock(ActionListener.class); + Task task = mock(Task.class); + + Map inferenceFieldMap = Map.of( + "sparse_field", + new 
InferenceFieldMetadata("sparse_field", sparseModel.getInferenceEntityId(), new String[] { "sparse_field" }, null) + ); + + BulkItemRequest[] items = new BulkItemRequest[4]; + items[0] = new BulkItemRequest(0, new IndexRequest("index").id("doc_0").source("non_inference_field", "foo")); + items[1] = new BulkItemRequest(1, new IndexRequest("index").id("doc_1").source(doc1Source)); + items[2] = new BulkItemRequest(2, new IndexRequest("index").id("doc_2").source(doc2Source)); + items[3] = new BulkItemRequest(3, new IndexRequest("index").id("doc_3").source("non_inference_field", "baz")); + + BulkShardRequest request = new BulkShardRequest(new ShardId("test", "test", 0), WriteRequest.RefreshPolicy.NONE, items); + request.setInferenceFieldMap(inferenceFieldMap); + filter.apply(task, TransportShardBulkAction.ACTION_NAME, request, actionListener, actionFilterChain); + awaitLatch(chainExecuted, 10, TimeUnit.SECONDS); + + IndexingPressure.Coordinating coordinatingIndexingPressure = indexingPressure.getCoordinating(); + assertThat(coordinatingIndexingPressure, notNullValue()); + verify(coordinatingIndexingPressure).close(); + } + @SuppressWarnings("unchecked") private static ShardBulkInferenceActionFilter createFilter( ThreadPool threadPool, Map modelMap, + IndexingPressure indexingPressure, boolean useLegacyFormat, boolean isLicenseValidForInference ) { @@ -503,6 +916,17 @@ private static ShardBulkInferenceActionFilter createFilter( }; doAnswer(unparsedModelAnswer).when(modelRegistry).getModelWithSecrets(any(), any()); + Answer minimalServiceSettingsAnswer = invocationOnMock -> { + String inferenceId = (String) invocationOnMock.getArguments()[0]; + var model = modelMap.get(inferenceId); + if (model == null) { + throw new ResourceNotFoundException("model id [{}] not found", inferenceId); + } + + return new MinimalServiceSettings(model); + }; + doAnswer(minimalServiceSettingsAnswer).when(modelRegistry).getMinimalServiceSettings(any()); + InferenceService inferenceService = mock(InferenceService.class); Answer chunkedInferAnswer = invocationOnMock -> { StaticModel model = (StaticModel) invocationOnMock.getArguments()[0]; @@ -544,7 +968,8 @@ private static ShardBulkInferenceActionFilter createFilter( createClusterService(useLegacyFormat), inferenceServiceRegistry, modelRegistry, - licenseState + licenseState, + indexingPressure ); } @@ -629,6 +1054,10 @@ private static BulkItemRequest[] randomBulkItemRequest( new BulkItemRequest(requestId, new IndexRequest("index").source(expectedDocMap, requestContentType)) }; } + private static long bytesUsed(XContentBuilder builder) { + return BytesReference.bytes(builder).ramBytesUsed(); + } + @SuppressWarnings({ "unchecked" }) private static void assertInferenceResults( boolean useLegacyFormat, @@ -693,7 +1122,11 @@ private static class StaticModel extends TestModel { } public static StaticModel createRandomInstance() { - TestModel testModel = TestModel.createRandomInstance(); + return createRandomInstance(randomFrom(TaskType.TEXT_EMBEDDING, TaskType.SPARSE_EMBEDDING)); + } + + public static StaticModel createRandomInstance(TaskType taskType) { + TestModel testModel = TestModel.createRandomInstance(taskType); return new StaticModel( testModel.getInferenceEntityId(), testModel.getTaskType(), @@ -716,4 +1149,42 @@ boolean hasResult(String text) { return resultMap.containsKey(text); } } + + private static class InstrumentedIndexingPressure extends IndexingPressure { + private Coordinating coordinating = null; + + private InstrumentedIndexingPressure(Settings settings) { 
+ super(settings); + } + + private Coordinating getCoordinating() { + return coordinating; + } + + @Override + public Coordinating createCoordinatingOperation(boolean forceExecution) { + coordinating = spy(super.createCoordinatingOperation(forceExecution)); + return coordinating; + } + } + + private static class NoopIndexingPressure extends IndexingPressure { + private NoopIndexingPressure() { + super(Settings.EMPTY); + } + + @Override + public Coordinating createCoordinatingOperation(boolean forceExecution) { + return new NoopCoordinating(forceExecution); + } + + private class NoopCoordinating extends Coordinating { + private NoopCoordinating(boolean forceExecution) { + super(forceExecution); + } + + @Override + public void increment(int operations, long bytes) {} + } + } } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldTests.java index 6e90e9d3af12e..8c1b2321f94fa 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldTests.java @@ -185,6 +185,17 @@ public void testModelSettingsValidation() { assertThat(ex.getMessage(), containsString("required [element_type] field is missing")); } + public static ChunkedInferenceEmbedding randomChunkedInferenceEmbedding(Model model, List inputs) { + return switch (model.getTaskType()) { + case SPARSE_EMBEDDING -> randomChunkedInferenceEmbeddingSparse(inputs); + case TEXT_EMBEDDING -> switch (model.getServiceSettings().elementType()) { + case FLOAT -> randomChunkedInferenceEmbeddingFloat(model, inputs); + case BIT, BYTE -> randomChunkedInferenceEmbeddingByte(model, inputs); + }; + default -> throw new AssertionError("invalid task type: " + model.getTaskType().name()); + }; + } + public static ChunkedInferenceEmbedding randomChunkedInferenceEmbeddingByte(Model model, List inputs) { DenseVectorFieldMapper.ElementType elementType = model.getServiceSettings().elementType(); int embeddingLength = DenseVectorFieldMapperTestUtils.getEmbeddingLength(elementType, model.getServiceSettings().dimensions()); From 53b541531042b42497789c6e53da0a4baea86fa0 Mon Sep 17 00:00:00 2001 From: elasticsearchmachine Date: Mon, 28 Apr 2025 12:39:16 +0000 Subject: [PATCH 2/2] [CI] Auto commit changes from spotless --- .../action/filter/ShardBulkInferenceActionFilterTests.java | 1 - 1 file changed, 1 deletion(-) diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilterTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilterTests.java index 669900cc5d710..df860780366ac 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilterTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilterTests.java @@ -101,7 +101,6 @@ import static org.hamcrest.Matchers.notNullValue; import static org.mockito.ArgumentMatchers.anyInt; import static org.mockito.ArgumentMatchers.anyLong; -import static org.mockito.ArgumentMatchers.anyString; import static org.mockito.ArgumentMatchers.eq; import static org.mockito.ArgumentMatchers.longThat; import static org.mockito.Mockito.any;
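
A minimal illustrative sketch, not code from this patch: it shows how the coordinating indexing pressure introduced above is meant to be used when a document source is about to be rewritten with inference results. The class and method names (CoordinatingPressureSketch, accountForDocument) are placeholders; the Coordinating instance would come from indexingPressure.createCoordinatingOperation(false), with the IndexingPressure obtained through Plugin.PluginServices#indexingPressure().

    import org.elasticsearch.action.index.IndexRequest;
    import org.elasticsearch.common.util.concurrent.EsRejectedExecutionException;
    import org.elasticsearch.index.IndexingPressure;

    final class CoordinatingPressureSketch {
        /** Accounts for one document whose source will be rewritten with chunked inference results. */
        static boolean accountForDocument(IndexingPressure.Coordinating coordinating, IndexRequest indexRequest) {
            try {
                // One operation per document source update; bytes are the current source size.
                coordinating.increment(1, indexRequest.source().ramBytesUsed());
                return true;
            } catch (EsRejectedExecutionException e) {
                // indexing_pressure.memory.limit was exceeded: the filter turns this into a
                // per-item TOO_MANY_REQUESTS failure instead of risking an OOM.
                return false;
            }
        }
    }

After the inference results are injected, the filter accounts for the source growth with coordinating.increment(0, newSourceBytes - originalSourceBytes), and the Coordinating releasable is closed once the downstream bulk request completes (via ActionListener.releaseAfter(listener, coordinating)), releasing the accounted bytes.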