From 4739c5fbb5f0bc2165810bc70c0743286c15fd9d Mon Sep 17 00:00:00 2001 From: Ed Savage Date: Fri, 4 Apr 2025 15:09:54 +1300 Subject: [PATCH 1/7] [ML] Handle new `actual_memory_usage_bytes` field in model size stats. The `actual_memory_usage_bytes` field represents the real, physical memory allocated to the `autodetect` process as reported by the OS. Reporting this value in the model size stats associated with an AD job is useful, especially in OOM situations. --- .../org/elasticsearch/TransportVersions.java | 1 + .../xpack/core/ml/MlConfigVersion.java | 3 +- .../autodetect/state/ModelSizeStats.java | 35 ++++++++++++++++++- .../job/persistence/JobResultsProvider.java | 7 ++++ .../autodetect/AutodetectProcessManager.java | 2 ++ .../AutodetectProcessManagerTests.java | 3 ++ 6 files changed, 49 insertions(+), 2 deletions(-) diff --git a/server/src/main/java/org/elasticsearch/TransportVersions.java b/server/src/main/java/org/elasticsearch/TransportVersions.java index 178aee182dc42..8eec3490c8313 100644 --- a/server/src/main/java/org/elasticsearch/TransportVersions.java +++ b/server/src/main/java/org/elasticsearch/TransportVersions.java @@ -213,6 +213,7 @@ static TransportVersion def(int id) { public static final TransportVersion REMOTE_EXCEPTION = def(9_044_0_00); public static final TransportVersion ESQL_REMOVE_AGGREGATE_TYPE = def(9_045_0_00); public static final TransportVersion ADD_PROJECT_ID_TO_DSL_ERROR_INFO = def(9_046_0_00); + public static final TransportVersion ML_AD_ACTUAL_MEMORY_USAGE = def(9_047_0_00); /* * STOP! READ THIS FIRST! No, really, diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/MlConfigVersion.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/MlConfigVersion.java index 260409db0e653..06e5ff5b8d08c 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/MlConfigVersion.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/MlConfigVersion.java @@ -153,12 +153,13 @@ private static void checkUniqueness(int id, String uniqueId) { // V_11 is used in ELSER v2 package configs public static final MlConfigVersion V_11 = registerMlConfigVersion(11_00_0_0_99, "79CB2950-57C7-11EE-AE5D-0800200C9A66"); public static final MlConfigVersion V_12 = registerMlConfigVersion(12_00_0_0_99, "Trained model config prefix strings added"); + public static final MlConfigVersion V_13 = registerMlConfigVersion(13_00_0_0_99, "Anomaly Detection reports actual memory usage"); /** * Reference to the most recent Ml config version. * This should be the Ml config version with the highest id. 
*/ - public static final MlConfigVersion CURRENT = V_12; + public static final MlConfigVersion CURRENT = V_13; /** * Reference to the first MlConfigVersion that is detached from the diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/job/process/autodetect/state/ModelSizeStats.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/job/process/autodetect/state/ModelSizeStats.java index 3a9123445e697..0757c9424b309 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/job/process/autodetect/state/ModelSizeStats.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/job/process/autodetect/state/ModelSizeStats.java @@ -41,6 +41,7 @@ public class ModelSizeStats implements ToXContentObject, Writeable { */ public static final ParseField MODEL_BYTES_FIELD = new ParseField("model_bytes"); public static final ParseField PEAK_MODEL_BYTES_FIELD = new ParseField("peak_model_bytes"); + public static final ParseField ACTUAL_MEMORY_USAGE_BYTES = new ParseField("actual_memory_usage_bytes"); public static final ParseField MODEL_BYTES_EXCEEDED_FIELD = new ParseField("model_bytes_exceeded"); public static final ParseField MODEL_BYTES_MEMORY_LIMIT_FIELD = new ParseField("model_bytes_memory_limit"); public static final ParseField TOTAL_BY_FIELD_COUNT_FIELD = new ParseField("total_by_field_count"); @@ -74,6 +75,7 @@ private static ConstructingObjectParser createParser(boolean igno parser.declareString((modelSizeStat, s) -> {}, Result.RESULT_TYPE); parser.declareLong(Builder::setModelBytes, MODEL_BYTES_FIELD); parser.declareLong(Builder::setPeakModelBytes, PEAK_MODEL_BYTES_FIELD); + parser.declareLong(Builder::setActualMemoryUsageBytes, ACTUAL_MEMORY_USAGE_BYTES); parser.declareLong(Builder::setModelBytesExceeded, MODEL_BYTES_EXCEEDED_FIELD); parser.declareLong(Builder::setModelBytesMemoryLimit, MODEL_BYTES_MEMORY_LIMIT_FIELD); parser.declareLong(Builder::setBucketAllocationFailuresCount, BUCKET_ALLOCATION_FAILURES_COUNT_FIELD); @@ -152,6 +154,7 @@ public String toString() { * 1. The job's model_memory_limit * 2. The current model memory, i.e. what's reported in model_bytes of this object * 3. The peak model memory, i.e. what's reported in peak_model_bytes of this object + * 4. The actual memory usage, i.e. what's reported in actual_memory_usage_bytes of this object * The field storing this enum can also be null, which means the * assignment code will decide on the fly - this was the old behaviour prior * to 7.11. 
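The serialization hunks below gate the new optional field on the `ML_AD_ACTUAL_MEMORY_USAGE` transport version defined above, so nodes on older versions neither write nor read it. A minimal round-trip sketch of that contract, for reviewers: the builder, accessor, and version constants are taken from this patch, while the test harness, job id, and byte values are illustrative (ESTestCase-style, with `assertNull` assumed from the surrounding harness).

```java
import java.io.IOException;

import org.elasticsearch.TransportVersions;
import org.elasticsearch.common.io.stream.BytesStreamOutput;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.xpack.core.ml.job.process.autodetect.state.ModelSizeStats;

public void testActualMemoryUsageBytesIsDroppedOnOlderStreams() throws IOException {
    ModelSizeStats stats = new ModelSizeStats.Builder("job-1") // job id is illustrative
        .setModelBytes(1024L)
        .setActualMemoryUsageBytes(2048L)
        .build();
    try (BytesStreamOutput out = new BytesStreamOutput()) {
        // Pin both sides to the version immediately before ML_AD_ACTUAL_MEMORY_USAGE.
        out.setTransportVersion(TransportVersions.ADD_PROJECT_ID_TO_DSL_ERROR_INFO);
        stats.writeTo(out);
        StreamInput in = out.bytes().streamInput();
        in.setTransportVersion(TransportVersions.ADD_PROJECT_ID_TO_DSL_ERROR_INFO);
        // The field is skipped on both write and read, so it round-trips as
        // null rather than corrupting the rest of the stream.
        assertNull(new ModelSizeStats(in).getActualMemoryUsageBytes());
    }
}
```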
@@ -159,7 +162,8 @@ public String toString() { public enum AssignmentMemoryBasis implements Writeable { MODEL_MEMORY_LIMIT, CURRENT_MODEL_BYTES, - PEAK_MODEL_BYTES; + PEAK_MODEL_BYTES, + ACTUAL_MEMORY_USAGE_BYTES; public static AssignmentMemoryBasis fromString(String statusName) { return valueOf(statusName.trim().toUpperCase(Locale.ROOT)); } @@ -183,6 +187,7 @@ public String toString() { private final String jobId; private final long modelBytes; private final Long peakModelBytes; + private final Long actualMemoryUsageBytes; private final Long modelBytesExceeded; private final Long modelBytesMemoryLimit; private final long totalByFieldCount; @@ -206,6 +211,7 @@ private ModelSizeStats( String jobId, long modelBytes, Long peakModelBytes, + Long actualMemoryUsageBytes, Long modelBytesExceeded, Long modelBytesMemoryLimit, long totalByFieldCount, @@ -228,6 +234,7 @@ private ModelSizeStats( this.jobId = jobId; this.modelBytes = modelBytes; this.peakModelBytes = peakModelBytes; + this.actualMemoryUsageBytes = actualMemoryUsageBytes; this.modelBytesExceeded = modelBytesExceeded; this.modelBytesMemoryLimit = modelBytesMemoryLimit; this.totalByFieldCount = totalByFieldCount; @@ -252,6 +259,11 @@ public ModelSizeStats(StreamInput in) throws IOException { jobId = in.readString(); modelBytes = in.readVLong(); peakModelBytes = in.readOptionalLong(); + if (in.getTransportVersion().onOrAfter(TransportVersions.ML_AD_ACTUAL_MEMORY_USAGE)) { + actualMemoryUsageBytes = in.readOptionalLong(); + } else { + actualMemoryUsageBytes = null; + } modelBytesExceeded = in.readOptionalLong(); modelBytesMemoryLimit = in.readOptionalLong(); totalByFieldCount = in.readVLong(); @@ -293,6 +305,9 @@ public void writeTo(StreamOutput out) throws IOException { out.writeString(jobId); out.writeVLong(modelBytes); out.writeOptionalLong(peakModelBytes); + if (out.getTransportVersion().onOrAfter(TransportVersions.ML_AD_ACTUAL_MEMORY_USAGE)) { + out.writeOptionalLong(actualMemoryUsageBytes); + } out.writeOptionalLong(modelBytesExceeded); out.writeOptionalLong(modelBytesMemoryLimit); out.writeVLong(totalByFieldCount); @@ -339,6 +354,9 @@ public XContentBuilder doXContentBody(XContentBuilder builder) throws IOExceptio if (peakModelBytes != null) { builder.field(PEAK_MODEL_BYTES_FIELD.getPreferredName(), peakModelBytes); } + if (actualMemoryUsageBytes != null) { + builder.field(ACTUAL_MEMORY_USAGE_BYTES.getPreferredName(), actualMemoryUsageBytes); + } if (modelBytesExceeded != null) { builder.field(MODEL_BYTES_EXCEEDED_FIELD.getPreferredName(), modelBytesExceeded); } @@ -391,6 +409,10 @@ public Long getPeakModelBytes() { return peakModelBytes; } + public Long getActualMemoryUsageBytes() { + return actualMemoryUsageBytes; + } + public Long getModelBytesExceeded() { return modelBytesExceeded; } @@ -479,6 +501,7 @@ public int hashCode() { jobId, modelBytes, peakModelBytes, + actualMemoryUsageBytes, modelBytesExceeded, modelBytesMemoryLimit, totalByFieldCount, @@ -517,6 +540,7 @@ public boolean equals(Object other) { return this.modelBytes == that.modelBytes && Objects.equals(this.peakModelBytes, that.peakModelBytes) + && Objects.equals(this.actualMemoryUsageBytes, that.actualMemoryUsageBytes) && Objects.equals(this.modelBytesExceeded, that.modelBytesExceeded) && Objects.equals(this.modelBytesMemoryLimit, that.modelBytesMemoryLimit) && this.totalByFieldCount == that.totalByFieldCount @@ -543,6 +567,7 @@ public static class Builder { private final String jobId; private long modelBytes; private Long peakModelBytes; + private Long actualMemoryUsageBytes;
private Long modelBytesExceeded; private Long modelBytesMemoryLimit; private long totalByFieldCount; @@ -573,6 +598,7 @@ public Builder(ModelSizeStats modelSizeStats) { this.jobId = modelSizeStats.jobId; this.modelBytes = modelSizeStats.modelBytes; this.peakModelBytes = modelSizeStats.peakModelBytes; + this.actualMemoryUsageBytes = modelSizeStats.actualMemoryUsageBytes; this.modelBytesExceeded = modelSizeStats.modelBytesExceeded; this.modelBytesMemoryLimit = modelSizeStats.modelBytesMemoryLimit; this.totalByFieldCount = modelSizeStats.totalByFieldCount; @@ -603,6 +629,12 @@ public Builder setPeakModelBytes(long peakModelBytes) { return this; } + public Builder setActualMemoryUsageBytes(long actualMemoryUsageBytes) + { + this.actualMemoryUsageBytes = actualMemoryUsageBytes; + return this; + } + public Builder setModelBytesExceeded(long modelBytesExceeded) { this.modelBytesExceeded = modelBytesExceeded; return this; @@ -700,6 +732,7 @@ public ModelSizeStats build() { jobId, modelBytes, peakModelBytes, + actualMemoryUsageBytes, modelBytesExceeded, modelBytesMemoryLimit, totalByFieldCount, diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/job/persistence/JobResultsProvider.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/job/persistence/JobResultsProvider.java index 56cd1948b021f..cc685e29e3d04 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/job/persistence/JobResultsProvider.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/job/persistence/JobResultsProvider.java @@ -1590,6 +1590,13 @@ void calculateEstablishedMemoryUsage( handler.accept((storedPeak != null) ? storedPeak : latestModelSizeStats.getModelBytes()); return; } + case ACTUAL_MEMORY_USAGE_BYTES -> { + Long storedActualMemoryUsageBytes = latestModelSizeStats.getActualMemoryUsageBytes(); + handler.accept( + (storedActualMemoryUsageBytes != null) ? 
storedActualMemoryUsageBytes : latestModelSizeStats.getModelBytes() + ); + return; + } } } diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/job/process/autodetect/AutodetectProcessManager.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/job/process/autodetect/AutodetectProcessManager.java index d003578158f48..ba10d0c8090f6 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/job/process/autodetect/AutodetectProcessManager.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/job/process/autodetect/AutodetectProcessManager.java @@ -1076,6 +1076,8 @@ public ByteSizeValue getOpenProcessMemoryUsage() { case MODEL_MEMORY_LIMIT -> Optional.ofNullable(modelSizeStats.getModelBytesMemoryLimit()).orElse(0L); case CURRENT_MODEL_BYTES -> modelSizeStats.getModelBytes(); case PEAK_MODEL_BYTES -> Optional.ofNullable(modelSizeStats.getPeakModelBytes()).orElse(modelSizeStats.getModelBytes()); + case ACTUAL_MEMORY_USAGE_BYTES -> Optional.ofNullable(modelSizeStats.getActualMemoryUsageBytes()) + .orElse(modelSizeStats.getModelBytes()); }; memoryUsedBytes += Job.PROCESS_MEMORY_OVERHEAD.getBytes(); } diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/job/process/autodetect/AutodetectProcessManagerTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/job/process/autodetect/AutodetectProcessManagerTests.java index fcf2f2e32f16b..03aeb18ebd3b1 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/job/process/autodetect/AutodetectProcessManagerTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/job/process/autodetect/AutodetectProcessManagerTests.java @@ -834,10 +834,12 @@ public void testGetOpenProcessMemoryUsage() { long modelMemoryLimitBytes = ByteSizeValue.ofMb(randomIntBetween(10, 1000)).getBytes(); long peakModelBytes = randomLongBetween(100000, modelMemoryLimitBytes - 1); long modelBytes = randomLongBetween(1, peakModelBytes - 1); + long actualMemoryUsageBytes = randomLongBetween(262144, peakModelBytes - 1); AssignmentMemoryBasis assignmentMemoryBasis = randomFrom(AssignmentMemoryBasis.values()); modelSizeStats = new ModelSizeStats.Builder("foo").setModelBytesMemoryLimit(modelMemoryLimitBytes) .setPeakModelBytes(peakModelBytes) .setModelBytes(modelBytes) + .setActualMemoryUsageBytes(actualMemoryUsageBytes) .setAssignmentMemoryBasis(assignmentMemoryBasis) .build(); when(autodetectCommunicator.getModelSizeStats()).thenReturn(modelSizeStats); @@ -850,6 +852,7 @@ public void testGetOpenProcessMemoryUsage() { case MODEL_MEMORY_LIMIT -> modelMemoryLimitBytes; case CURRENT_MODEL_BYTES -> modelBytes; case PEAK_MODEL_BYTES -> peakModelBytes; + case ACTUAL_MEMORY_USAGE_BYTES -> actualMemoryUsageBytes; }; assertThat(manager.getOpenProcessMemoryUsage(), equalTo(ByteSizeValue.ofBytes(expectedSizeBytes))); } From 87cef320123473fe9373982b59e89458bf6d6929 Mon Sep 17 00:00:00 2001 From: Ed Savage Date: Fri, 4 Apr 2025 15:15:52 +1300 Subject: [PATCH 2/7] Update docs/changelog/126256.yaml --- docs/changelog/126256.yaml | 5 +++++ 1 file changed, 5 insertions(+) create mode 100644 docs/changelog/126256.yaml diff --git a/docs/changelog/126256.yaml b/docs/changelog/126256.yaml new file mode 100644 index 0000000000000..51f4ea756636f --- /dev/null +++ b/docs/changelog/126256.yaml @@ -0,0 +1,5 @@ +pr: 126256 +summary: Handle new `actual_memory_usage_bytes` field in model size stats +area: Machine Learning +type: enhancement +issues: [] From e2eec0e89fea315e3765c1d1a50d411b880da4d1 Mon Sep 
17 00:00:00 2001 From: Ed Savage Date: Fri, 4 Apr 2025 15:27:54 +1300 Subject: [PATCH 3/7] Update from main --- docs/changelog/121041.yaml | 5 + .../mapping-reference/semantic-text.md | 160 +++-- .../ingest/common/CommunityIdProcessor.java | 167 +++--- .../CommunityIdProcessorFactoryTests.java | 67 ++- .../common/CommunityIdProcessorTests.java | 147 ++--- .../metadata/InferenceFieldMetadata.java | 52 +- .../inference/ChunkInferenceInput.java | 25 + .../inference/ChunkingSettings.java | 4 + .../inference/InferenceService.java | 16 +- .../cluster/metadata/IndexMetadataTests.java | 3 +- .../metadata/InferenceFieldMetadataTests.java | 56 +- ...appingLookupInferenceFieldMapperTests.java | 9 +- .../mock/AbstractTestInferenceService.java | 37 ++ .../TestDenseInferenceServiceExtension.java | 27 +- .../mock/TestRerankingServiceExtension.java | 3 +- .../TestSparseInferenceServiceExtension.java | 34 +- ...stStreamingCompletionServiceExtension.java | 3 +- .../inference/src/main/java/module-info.java | 1 + .../xpack/inference/InferenceFeatures.java | 4 +- .../ShardBulkInferenceActionFilter.java | 18 +- .../chunking/ChunkingSettingsBuilder.java | 26 +- .../chunking/EmbeddingRequestChunker.java | 43 +- .../SentenceBoundaryChunkingSettings.java | 18 + .../WordBoundaryChunkingSettings.java | 26 + ...baCloudSearchEmbeddingsRequestManager.java | 2 +- ...libabaCloudSearchSparseRequestManager.java | 2 +- ...AmazonBedrockEmbeddingsRequestManager.java | 2 +- ...AzureAiStudioEmbeddingsRequestManager.java | 2 +- .../AzureOpenAiEmbeddingsRequestManager.java | 2 +- .../CohereEmbeddingsRequestManager.java | 2 +- ...ServiceSparseEmbeddingsRequestManager.java | 2 +- .../external/http/sender/EmbeddingsInput.java | 25 +- ...oogleAiStudioEmbeddingsRequestManager.java | 2 +- ...oogleVertexAiEmbeddingsRequestManager.java | 2 +- .../sender/HuggingFaceRequestManager.java | 2 +- .../IbmWatsonxEmbeddingsRequestManager.java | 2 +- .../JinaAIEmbeddingsRequestManager.java | 2 +- .../MistralEmbeddingsRequestManager.java | 2 +- .../http/sender/TruncatingRequestManager.java | 2 +- .../inference/mapper/SemanticTextField.java | 47 +- .../mapper/SemanticTextFieldMapper.java | 36 +- .../inference/services/SenderService.java | 13 +- .../AlibabaCloudSearchService.java | 2 +- .../amazonbedrock/AmazonBedrockService.java | 2 +- .../azureaistudio/AzureAiStudioService.java | 2 +- .../azureopenai/AzureOpenAiService.java | 2 +- .../services/cohere/CohereService.java | 2 +- .../elastic/ElasticInferenceService.java | 2 +- .../ElasticsearchInternalService.java | 14 +- .../googleaistudio/GoogleAiStudioService.java | 8 +- .../googlevertexai/GoogleVertexAiService.java | 2 +- .../huggingface/HuggingFaceService.java | 2 +- .../elser/HuggingFaceElserService.java | 7 +- .../ibmwatsonx/IbmWatsonxService.java | 4 +- .../services/jinaai/JinaAIService.java | 2 +- .../services/mistral/MistralService.java | 2 +- .../services/openai/OpenAiService.java | 2 +- .../services/voyageai/VoyageAIService.java | 2 +- .../action/VoyageAIActionCreator.java | 2 +- ...KnnVectorQueryRewriteInterceptorTests.java | 4 +- ...nticMatchQueryRewriteInterceptorTests.java | 2 +- ...rseVectorQueryRewriteInterceptorTests.java | 4 +- .../ShardBulkInferenceActionFilterTests.java | 28 +- .../ChunkingSettingsBuilderTests.java | 8 +- .../EmbeddingRequestChunkerTests.java | 230 +++++--- ...ingleInputSenderExecutableActionTests.java | 3 +- .../http/sender/HttpRequestSenderTests.java | 3 +- .../http/sender/RequestTaskTests.java | 11 +- ...cInferenceMetadataFieldsRecoveryTests.java | 27 +- 
.../mapper/SemanticTextFieldMapperTests.java | 114 +++- .../mapper/SemanticTextFieldTests.java | 36 +- .../queries/SemanticQueryBuilderTests.java | 2 +- .../AlibabaCloudSearchServiceTests.java | 3 +- ...ibabaCloudSearchCompletionActionTests.java | 2 +- .../AmazonBedrockServiceTests.java | 3 +- .../AmazonBedrockActionCreatorTests.java | 6 +- .../AmazonBedrockMockRequestSender.java | 3 +- .../AmazonBedrockRequestSenderTests.java | 3 +- .../AzureAiStudioServiceTests.java | 3 +- .../AzureAiStudioActionAndCreatorTests.java | 2 +- .../azureopenai/AzureOpenAiServiceTests.java | 3 +- .../action/AzureOpenAiActionCreatorTests.java | 16 +- .../AzureOpenAiEmbeddingsActionTests.java | 31 +- .../services/cohere/CohereServiceTests.java | 5 +- .../action/CohereActionCreatorTests.java | 2 +- .../action/CohereEmbeddingsActionTests.java | 14 +- .../elastic/ElasticInferenceServiceTests.java | 5 +- ...ticInferenceServiceActionCreatorTests.java | 8 +- .../ElasticsearchInternalServiceTests.java | 17 +- .../GoogleAiStudioServiceTests.java | 11 +- .../GoogleAiStudioEmbeddingsActionTests.java | 8 +- .../GoogleVertexAiEmbeddingsActionTests.java | 6 +- .../GoogleVertexAiRerankActionTests.java | 6 +- .../HuggingFaceElserServiceTests.java | 3 +- .../huggingface/HuggingFaceServiceTests.java | 5 +- .../action/HuggingFaceActionCreatorTests.java | 12 +- .../action/HuggingFaceActionTests.java | 6 +- .../ibmwatsonx/IbmWatsonxServiceTests.java | 7 +- .../IbmWatsonxEmbeddingsActionTests.java | 8 +- .../services/jinaai/JinaAIServiceTests.java | 3 +- .../services/mistral/MistralServiceTests.java | 3 +- .../services/openai/OpenAiServiceTests.java | 3 +- .../action/OpenAiActionCreatorTests.java | 14 +- .../action/OpenAiEmbeddingsActionTests.java | 13 +- .../voyageai/VoyageAIServiceTests.java | 3 +- .../action/VoyageAIActionCreatorTests.java | 7 +- .../action/VoyageAIEmbeddingsActionTests.java | 27 +- ...5_semantic_text_field_mapping_chunking.yml | 523 +++++++++++++++++ ...mantic_text_field_mapping_chunking_bwc.yml | 550 ++++++++++++++++++ .../60_semantic_text_inference_update.yml | 28 +- 110 files changed, 2385 insertions(+), 618 deletions(-) create mode 100644 docs/changelog/121041.yaml create mode 100644 server/src/main/java/org/elasticsearch/inference/ChunkInferenceInput.java create mode 100644 x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/25_semantic_text_field_mapping_chunking.yml create mode 100644 x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/25_semantic_text_field_mapping_chunking_bwc.yml diff --git a/docs/changelog/121041.yaml b/docs/changelog/121041.yaml new file mode 100644 index 0000000000000..44a51a966c0a1 --- /dev/null +++ b/docs/changelog/121041.yaml @@ -0,0 +1,5 @@ +pr: 121041 +summary: Support configurable chunking in `semantic_text` fields +area: Relevance +type: enhancement +issues: [] diff --git a/docs/reference/elasticsearch/mapping-reference/semantic-text.md b/docs/reference/elasticsearch/mapping-reference/semantic-text.md index abd44df2a139a..22a19c84e3cee 100644 --- a/docs/reference/elasticsearch/mapping-reference/semantic-text.md +++ b/docs/reference/elasticsearch/mapping-reference/semantic-text.md @@ -6,15 +6,32 @@ mapped_pages: # Semantic text field type [semantic-text] -The `semantic_text` field type automatically generates embeddings for text content using an inference endpoint. Long passages are [automatically chunked](#auto-text-chunking) to smaller sections to enable the processing of larger corpuses of text. 
- -The `semantic_text` field type specifies an inference endpoint identifier that will be used to generate embeddings. You can create the inference endpoint by using the [Create {{infer}} API](https://www.elastic.co/docs/api/doc/elasticsearch/operation/operation-inference-put). This field type and the [`semantic` query](/reference/query-languages/query-dsl/query-dsl-semantic-query.md) type make it simpler to perform semantic search on your data. The `semantic_text` field type may also be queried with [match](/reference/query-languages/query-dsl/query-dsl-match-query.md), [sparse_vector](/reference/query-languages/query-dsl/query-dsl-sparse-vector-query.md) or [knn](/reference/query-languages/query-dsl/query-dsl-knn-query.md) queries. - -If you don’t specify an inference endpoint, the `inference_id` field defaults to `.elser-2-elasticsearch`, a preconfigured endpoint for the elasticsearch service. - -Using `semantic_text`, you won’t need to specify how to generate embeddings for your data, or how to index it. The {{infer}} endpoint automatically determines the embedding generation, indexing, and query to use. - -If you use the preconfigured `.elser-2-elasticsearch` endpoint, you can set up `semantic_text` with the following API request: +The `semantic_text` field type automatically generates embeddings for text +content using an inference endpoint. Long passages +are [automatically chunked](#auto-text-chunking) to smaller sections to enable +the processing of larger corpuses of text. + +The `semantic_text` field type specifies an inference endpoint identifier that +will be used to generate embeddings. You can create the inference endpoint by +using +the [Create {{infer}} API](https://www.elastic.co/docs/api/doc/elasticsearch/operation/operation-inference-put). +This field type and the [ +`semantic` query](/reference/query-languages/query-dsl/query-dsl-semantic-query.md) +type make it simpler to perform semantic search on your data. The +`semantic_text` field type may also be queried +with [match](/reference/query-languages/query-dsl/query-dsl-match-query.md), [sparse_vector](/reference/query-languages/query-dsl/query-dsl-sparse-vector-query.md) +or [knn](/reference/query-languages/query-dsl/query-dsl-knn-query.md) queries. + +If you don’t specify an inference endpoint, the `inference_id` field defaults to +`.elser-2-elasticsearch`, a preconfigured endpoint for the elasticsearch +service. + +Using `semantic_text`, you won’t need to specify how to generate embeddings for +your data, or how to index it. The {{infer}} endpoint automatically determines +the embedding generation, indexing, and query to use. + +If you use the preconfigured `.elser-2-elasticsearch` endpoint, you can set up +`semantic_text` with the following API request: ```console PUT my-index-000001 @@ -29,7 +46,10 @@ PUT my-index-000001 } ``` -To use a custom {{infer}} endpoint instead of the default `.elser-2-elasticsearch`, you must [Create {{infer}} API](https://www.elastic.co/docs/api/doc/elasticsearch/operation/operation-inference-put) and specify its `inference_id` when setting up the `semantic_text` field type. +To use a custom {{infer}} endpoint instead of the default +`.elser-2-elasticsearch`, you +must [Create {{infer}} API](https://www.elastic.co/docs/api/doc/elasticsearch/operation/operation-inference-put) +and specify its `inference_id` when setting up the `semantic_text` field type. ```console PUT my-index-000002 @@ -47,8 +67,12 @@ PUT my-index-000002 1. 
The `inference_id` of the {{infer}} endpoint to use to generate embeddings. - -The recommended way to use `semantic_text` is by having dedicated {{infer}} endpoints for ingestion and search. This ensures that search speed remains unaffected by ingestion workloads, and vice versa. After creating dedicated {{infer}} endpoints for both, you can reference them using the `inference_id` and `search_inference_id` parameters when setting up the index mapping for an index that uses the `semantic_text` field. +The recommended way to use `semantic_text` is by having dedicated {{infer}} +endpoints for ingestion and search. This ensures that search speed remains +unaffected by ingestion workloads, and vice versa. After creating dedicated +{{infer}} endpoints for both, you can reference them using the `inference_id` +and `search_inference_id` parameters when setting up the index mapping for an +index that uses the `semantic_text` field. ```console PUT my-index-000003 @@ -65,40 +89,71 @@ PUT my-index-000003 } ``` - ## Parameters for `semantic_text` fields [semantic-text-params] `inference_id` -: (Required, string) {{infer-cap}} endpoint that will be used to generate embeddings for the field. By default, `.elser-2-elasticsearch` is used. This parameter cannot be updated. Use the [Create {{infer}} API](https://www.elastic.co/docs/api/doc/elasticsearch/operation/operation-inference-put) to create the endpoint. If `search_inference_id` is specified, the {{infer}} endpoint will only be used at index time. +: (Required, string) {{infer-cap}} endpoint that will be used to generate +embeddings for the field. By default, `.elser-2-elasticsearch` is used. This +parameter cannot be updated. Use +the [Create {{infer}} API](https://www.elastic.co/docs/api/doc/elasticsearch/operation/operation-inference-put) +to create the endpoint. If `search_inference_id` is specified, the {{infer}} +endpoint will only be used at index time. `search_inference_id` -: (Optional, string) {{infer-cap}} endpoint that will be used to generate embeddings at query time. You can update this parameter by using the [Update mapping API](https://www.elastic.co/docs/api/doc/elasticsearch/operation/operation-indices-put-mapping). Use the [Create {{infer}} API](https://www.elastic.co/docs/api/doc/elasticsearch/operation/operation-inference-put) to create the endpoint. If not specified, the {{infer}} endpoint defined by `inference_id` will be used at both index and query time. - +: (Optional, string) {{infer-cap}} endpoint that will be used to generate +embeddings at query time. You can update this parameter by using +the [Update mapping API](https://www.elastic.co/docs/api/doc/elasticsearch/operation/operation-indices-put-mapping). +Use +the [Create {{infer}} API](https://www.elastic.co/docs/api/doc/elasticsearch/operation/operation-inference-put) +to create the endpoint. If not specified, the {{infer}} endpoint defined by +`inference_id` will be used at both index and query time. + +`chunking_settings` +: (Optional, object) Sets chunking settings that will override the settings +configured by the `inference_id` endpoint. +See [chunking settings attributes](https://www.elastic.co/docs/api/doc/elasticsearch/operation/operation-inference-put) +in the {{infer}} API documentation for a complete list of available options. ## {{infer-cap}} endpoint validation [infer-endpoint-validation] -The `inference_id` will not be validated when the mapping is created, but when documents are ingested into the index. 
When the first document is indexed, the `inference_id` will be used to generate underlying indexing structures for the field. +The `inference_id` will not be validated when the mapping is created, but when +documents are ingested into the index. When the first document is indexed, the +`inference_id` will be used to generate underlying indexing structures for the +field. ::::{warning} -Removing an {{infer}} endpoint will cause ingestion of documents and semantic queries to fail on indices that define `semantic_text` fields with that {{infer}} endpoint as their `inference_id`. Trying to [delete an {{infer}} endpoint](https://www.elastic.co/docs/api/doc/elasticsearch/operation/operation-inference-delete) that is used on a `semantic_text` field will result in an error. +Removing an {{infer}} endpoint will cause ingestion of documents and semantic +queries to fail on indices that define `semantic_text` fields with that +{{infer}} endpoint as their `inference_id`. Trying +to [delete an {{infer}} endpoint](https://www.elastic.co/docs/api/doc/elasticsearch/operation/operation-inference-delete) +that is used on a `semantic_text` field will result in an error. :::: - - ## Text chunking [auto-text-chunking] -{{infer-cap}} endpoints have a limit on the amount of text they can process. To allow for large amounts of text to be used in semantic search, `semantic_text` automatically generates smaller passages if needed, called *chunks*. +{{infer-cap}} endpoints have a limit on the amount of text they can process. To +allow for large amounts of text to be used in semantic search, `semantic_text` +automatically generates smaller passages if needed, called *chunks*. -Each chunk refers to a passage of the text and the corresponding embedding generated from it. When querying, the individual passages will be automatically searched for each document, and the most relevant passage will be used to compute a score. +Each chunk refers to a passage of the text and the corresponding embedding +generated from it. When querying, the individual passages will be automatically +searched for each document, and the most relevant passage will be used to +compute a score. -For more details on chunking and how to configure chunking settings, see [Configuring chunking](https://www.elastic.co/docs/api/doc/elasticsearch/group/endpoint-inference) in the Inference API documentation. - -Refer to [this tutorial](docs-content://solutions/search/semantic-search/semantic-search-semantic-text.md) to learn more about semantic search using `semantic_text` and the `semantic` query. +For more details on chunking and how to configure chunking settings, +see [Configuring chunking](https://www.elastic.co/docs/api/doc/elasticsearch/group/endpoint-inference) +in the Inference API documentation. +Refer +to [this tutorial](docs-content://solutions/search/semantic-search/semantic-search-semantic-text.md) +to learn more about semantic search using `semantic_text` and the `semantic` +query. ## Extracting Relevant Fragments from Semantic Text [semantic-text-highlighting] -You can extract the most relevant fragments from a semantic text field by using the [highlight parameter](/reference/elasticsearch/rest-apis/highlighting.md) in the [Search API](https://www.elastic.co/docs/api/doc/elasticsearch/operation/operation-search). 
+You can extract the most relevant fragments from a semantic text field by using +the [highlight parameter](/reference/elasticsearch/rest-apis/highlighting.md) in +the [Search API](https://www.elastic.co/docs/api/doc/elasticsearch/operation/operation-search). ```console POST test-index/_search @@ -120,10 +175,13 @@ POST test-index/_search ``` 1. Specifies the maximum number of fragments to return. -2. Sorts highlighted fragments by score when set to `score`. By default, fragments will be output in the order they appear in the field (order: none). - +2. Sorts highlighted fragments by score when set to `score`. By default, + fragments will be output in the order they appear in the field (order: none). -Highlighting is supported on fields other than semantic_text. However, if you want to restrict highlighting to the semantic highlighter and return no fragments when the field is not of type semantic_text, you can explicitly enforce the `semantic` highlighter in the query: +Highlighting is supported on fields other than semantic_text. However, if you +want to restrict highlighting to the semantic highlighter and return no +fragments when the field is not of type semantic_text, you can explicitly +enforce the `semantic` highlighter in the query: ```console PUT test-index @@ -147,23 +205,42 @@ PUT test-index 1. Ensures that highlighting is applied exclusively to semantic_text fields. - - ## Customizing `semantic_text` indexing [custom-indexing] -`semantic_text` uses defaults for indexing data based on the {{infer}} endpoint specified. It enables you to quickstart your semantic search by providing automatic {{infer}} and a dedicated query so you don’t need to provide further details. - -In case you want to customize data indexing, use the [`sparse_vector`](/reference/elasticsearch/mapping-reference/sparse-vector.md) or [`dense_vector`](/reference/elasticsearch/mapping-reference/dense-vector.md) field types and create an ingest pipeline with an [{{infer}} processor](/reference/enrich-processor/inference-processor.md) to generate the embeddings. [This tutorial](docs-content://solutions/search/semantic-search/semantic-search-inference.md) walks you through the process. In these cases - when you use `sparse_vector` or `dense_vector` field types instead of the `semantic_text` field type to customize indexing - using the [`semantic_query`](/reference/query-languages/query-dsl/query-dsl-semantic-query.md) is not supported for querying the field data. - +`semantic_text` uses defaults for indexing data based on the {{infer}} endpoint +specified. It enables you to quickstart your semantic search by providing +automatic {{infer}} and a dedicated query so you don’t need to provide further +details. + +In case you want to customize data indexing, use the [ +`sparse_vector`](/reference/elasticsearch/mapping-reference/sparse-vector.md) +or [`dense_vector`](/reference/elasticsearch/mapping-reference/dense-vector.md) +field types and create an ingest pipeline with +an [{{infer}} processor](/reference/enrich-processor/inference-processor.md) to +generate the +embeddings. [This tutorial](docs-content://solutions/search/semantic-search/semantic-search-inference.md) +walks you through the process. In these cases - when you use `sparse_vector` or +`dense_vector` field types instead of the `semantic_text` field type to +customize indexing - using the [ +`semantic_query`](/reference/query-languages/query-dsl/query-dsl-semantic-query.md) +is not supported for querying the field data. 
## Updates to `semantic_text` fields [update-script] Updates that use scripts are not supported for an index contains a `semantic_text` field. Even if the script targets non-`semantic_text` fields, the update will fail when the index contains a `semantic_text` field. +Updates that use scripts are not supported for an index that contains a +`semantic_text` field. Even if the script targets non-`semantic_text` fields, +the update will fail when the index contains a `semantic_text` field. ## `copy_to` and multi-fields support [copy-to-support] The semantic_text field type can serve as the target of [copy_to fields](/reference/elasticsearch/mapping-reference/copy-to.md), be part of a [multi-field](/reference/elasticsearch/mapping-reference/multi-fields.md) structure, or contain [multi-fields](/reference/elasticsearch/mapping-reference/multi-fields.md) internally. This means you can use a single field to collect the values of other fields for semantic search. +The semantic_text field type can serve as the target +of [copy_to fields](/reference/elasticsearch/mapping-reference/copy-to.md), be +part of +a [multi-field](/reference/elasticsearch/mapping-reference/multi-fields.md) +structure, or +contain [multi-fields](/reference/elasticsearch/mapping-reference/multi-fields.md) +internally. This means you can use a single field to collect the values of other +fields for semantic search. For example, the following mapping: @@ -206,11 +283,12 @@ PUT test-index } ``` - ## Limitations [limitations] `semantic_text` field types have the following limitations: -* `semantic_text` fields are not currently supported as elements of [nested fields](/reference/elasticsearch/mapping-reference/nested.md). -* `semantic_text` fields can’t currently be set as part of [Dynamic templates](docs-content://manage-data/data-store/mapping/dynamic-templates.md). +* `semantic_text` fields are not currently supported as elements + of [nested fields](/reference/elasticsearch/mapping-reference/nested.md). +* `semantic_text` fields can’t currently be set as part + of [Dynamic templates](docs-content://manage-data/data-store/mapping/dynamic-templates.md).
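The next file diff refactors the `community_id` processor internals (immutable `Map.ofEntries` lookup tables, tightened constructor and method visibility) without changing its output. For orientation, a usage sketch against the processor's public static `apply` helper, built from the Beats-derived vector asserted in the updated tests below; the parameter order is inferred from the hunk and should be double-checked against the source.

```java
// TCP flow 128.232.110.120:34855 -> 66.35.250.204:80 with seed 0, no ICMP fields.
String id = CommunityIdProcessor.apply(
    "128.232.110.120", // source IP
    "66.35.250.204",   // destination IP
    null,              // iana_number (unset, so the transport name is used)
    "TCP",             // transport
    34855,             // source port
    80,                // destination port
    null,              // icmp type
    null,              // icmp code
    0                  // seed
);
// Per testBeatsValid below, this yields "1:LQU9qZlK+B5F3KDmev6m5PMibrg=".
```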
diff --git a/modules/ingest-common/src/main/java/org/elasticsearch/ingest/common/CommunityIdProcessor.java b/modules/ingest-common/src/main/java/org/elasticsearch/ingest/common/CommunityIdProcessor.java index cda21f34b575f..f37fcf46979dc 100644 --- a/modules/ingest-common/src/main/java/org/elasticsearch/ingest/common/CommunityIdProcessor.java +++ b/modules/ingest-common/src/main/java/org/elasticsearch/ingest/common/CommunityIdProcessor.java @@ -23,12 +23,12 @@ import java.security.MessageDigest; import java.security.NoSuchAlgorithmException; import java.util.Base64; -import java.util.HashMap; import java.util.List; import java.util.Locale; import java.util.Map; import java.util.function.Supplier; +import static java.util.Map.entry; import static org.elasticsearch.ingest.ConfigurationUtils.newConfigurationException; import static org.elasticsearch.ingest.ConfigurationUtils.readBooleanProperty; @@ -131,26 +131,26 @@ public boolean getIgnoreMissing() { } @Override - public IngestDocument execute(IngestDocument ingestDocument) throws Exception { - String sourceIp = ingestDocument.getFieldValue(sourceIpField, String.class, ignoreMissing); - String destinationIp = ingestDocument.getFieldValue(destinationIpField, String.class, ignoreMissing); - Object ianaNumber = ingestDocument.getFieldValue(ianaNumberField, Object.class, true); - Supplier transport = () -> ingestDocument.getFieldValue(transportField, Object.class, ignoreMissing); - Supplier sourcePort = () -> ingestDocument.getFieldValue(sourcePortField, Object.class, ignoreMissing); - Supplier destinationPort = () -> ingestDocument.getFieldValue(destinationPortField, Object.class, ignoreMissing); - Object icmpType = ingestDocument.getFieldValue(icmpTypeField, Object.class, true); - Object icmpCode = ingestDocument.getFieldValue(icmpCodeField, Object.class, true); + public IngestDocument execute(IngestDocument document) throws Exception { + String sourceIp = document.getFieldValue(sourceIpField, String.class, ignoreMissing); + String destinationIp = document.getFieldValue(destinationIpField, String.class, ignoreMissing); + Object ianaNumber = document.getFieldValue(ianaNumberField, Object.class, true); + Supplier transport = () -> document.getFieldValue(transportField, Object.class, ignoreMissing); + Supplier sourcePort = () -> document.getFieldValue(sourcePortField, Object.class, ignoreMissing); + Supplier destinationPort = () -> document.getFieldValue(destinationPortField, Object.class, ignoreMissing); + Object icmpType = document.getFieldValue(icmpTypeField, Object.class, true); + Object icmpCode = document.getFieldValue(icmpCodeField, Object.class, true); Flow flow = buildFlow(sourceIp, destinationIp, ianaNumber, transport, sourcePort, destinationPort, icmpType, icmpCode); if (flow == null) { if (ignoreMissing) { - return ingestDocument; + return document; } else { throw new IllegalArgumentException("unable to construct flow from document"); } } - ingestDocument.setFieldValue(targetField, flow.toCommunityId(seed)); - return ingestDocument; + document.setFieldValue(targetField, flow.toCommunityId(seed)); + return document; } public static String apply( @@ -164,7 +164,6 @@ public static String apply( Object icmpCode, int seed ) { - Flow flow = buildFlow( sourceIpAddrString, destIpAddrString, @@ -256,6 +255,7 @@ public String getType() { /** * Converts an integer in the range of an unsigned 16-bit integer to a big-endian byte pair */ + // visible for testing static byte[] toUint16(int num) { if (num < 0 || num > 65535) { throw new 
IllegalStateException("number [" + num + "] must be a value between 0 and 65535"); @@ -266,7 +266,7 @@ static byte[] toUint16(int num) { /** * Attempts to coerce an object to an integer */ - static int parseIntFromObjectOrString(Object o, String fieldName) { + private static int parseIntFromObjectOrString(Object o, String fieldName) { if (o == null) { return 0; } else if (o instanceof Number number) { @@ -296,28 +296,28 @@ public static final class Factory implements Processor.Factory { @Override public CommunityIdProcessor create( Map registry, - String processorTag, + String tag, String description, Map config, ProjectId projectId ) throws Exception { - String sourceIpField = ConfigurationUtils.readStringProperty(TYPE, processorTag, config, "source_ip", DEFAULT_SOURCE_IP); - String sourcePortField = ConfigurationUtils.readStringProperty(TYPE, processorTag, config, "source_port", DEFAULT_SOURCE_PORT); - String destIpField = ConfigurationUtils.readStringProperty(TYPE, processorTag, config, "destination_ip", DEFAULT_DEST_IP); - String destPortField = ConfigurationUtils.readStringProperty(TYPE, processorTag, config, "destination_port", DEFAULT_DEST_PORT); - String ianaNumberField = ConfigurationUtils.readStringProperty(TYPE, processorTag, config, "iana_number", DEFAULT_IANA_NUMBER); - String transportField = ConfigurationUtils.readStringProperty(TYPE, processorTag, config, "transport", DEFAULT_TRANSPORT); - String icmpTypeField = ConfigurationUtils.readStringProperty(TYPE, processorTag, config, "icmp_type", DEFAULT_ICMP_TYPE); - String icmpCodeField = ConfigurationUtils.readStringProperty(TYPE, processorTag, config, "icmp_code", DEFAULT_ICMP_CODE); - String targetField = ConfigurationUtils.readStringProperty(TYPE, processorTag, config, "target_field", DEFAULT_TARGET); - int seedInt = ConfigurationUtils.readIntProperty(TYPE, processorTag, config, "seed", 0); + String sourceIpField = ConfigurationUtils.readStringProperty(TYPE, tag, config, "source_ip", DEFAULT_SOURCE_IP); + String sourcePortField = ConfigurationUtils.readStringProperty(TYPE, tag, config, "source_port", DEFAULT_SOURCE_PORT); + String destIpField = ConfigurationUtils.readStringProperty(TYPE, tag, config, "destination_ip", DEFAULT_DEST_IP); + String destPortField = ConfigurationUtils.readStringProperty(TYPE, tag, config, "destination_port", DEFAULT_DEST_PORT); + String ianaNumberField = ConfigurationUtils.readStringProperty(TYPE, tag, config, "iana_number", DEFAULT_IANA_NUMBER); + String transportField = ConfigurationUtils.readStringProperty(TYPE, tag, config, "transport", DEFAULT_TRANSPORT); + String icmpTypeField = ConfigurationUtils.readStringProperty(TYPE, tag, config, "icmp_type", DEFAULT_ICMP_TYPE); + String icmpCodeField = ConfigurationUtils.readStringProperty(TYPE, tag, config, "icmp_code", DEFAULT_ICMP_CODE); + String targetField = ConfigurationUtils.readStringProperty(TYPE, tag, config, "target_field", DEFAULT_TARGET); + int seedInt = ConfigurationUtils.readIntProperty(TYPE, tag, config, "seed", 0); if (seedInt < 0 || seedInt > 65535) { - throw newConfigurationException(TYPE, processorTag, "seed", "must be a value between 0 and 65535"); + throw newConfigurationException(TYPE, tag, "seed", "must be a value between 0 and 65535"); } - boolean ignoreMissing = readBooleanProperty(TYPE, processorTag, config, "ignore_missing", true); + boolean ignoreMissing = readBooleanProperty(TYPE, tag, config, "ignore_missing", true); return new CommunityIdProcessor( - processorTag, + tag, description, sourceIpField, sourcePortField, @@ 
-335,9 +335,13 @@ public CommunityIdProcessor create( } /** - * Represents flow data per https://github.com/corelight/community-id-spec + * Represents flow data per the Community ID spec. */ - public static final class Flow { + private static final class Flow { + + private Flow() { + // this is only constructable from inside this file + } private static final List TRANSPORTS_WITH_PORTS = List.of( Transport.Type.Tcp, @@ -357,7 +361,7 @@ public static final class Flow { /** * @return true iff the source address/port is numerically less than the destination address/port as described - * at https://github.com/corelight/community-id-spec + * in the Community ID spec. */ boolean isOrdered() { int result = new BigInteger(1, source.getAddress()).compareTo(new BigInteger(1, destination.getAddress())); @@ -401,7 +405,7 @@ String toCommunityId(byte[] seed) { } } - static class Transport { + static final class Transport { public enum Type { Unknown(-1), Icmp(1), @@ -417,22 +421,19 @@ public enum Type { private final int transportNumber; - private static final Map TRANSPORT_NAMES; - - static { - TRANSPORT_NAMES = new HashMap<>(); - TRANSPORT_NAMES.put("icmp", Icmp); - TRANSPORT_NAMES.put("igmp", Igmp); - TRANSPORT_NAMES.put("tcp", Tcp); - TRANSPORT_NAMES.put("udp", Udp); - TRANSPORT_NAMES.put("gre", Gre); - TRANSPORT_NAMES.put("ipv6-icmp", IcmpIpV6); - TRANSPORT_NAMES.put("icmpv6", IcmpIpV6); - TRANSPORT_NAMES.put("eigrp", Eigrp); - TRANSPORT_NAMES.put("ospf", Ospf); - TRANSPORT_NAMES.put("pim", Pim); - TRANSPORT_NAMES.put("sctp", Sctp); - } + private static final Map TRANSPORT_NAMES = Map.ofEntries( + entry("icmp", Icmp), + entry("igmp", Igmp), + entry("tcp", Tcp), + entry("udp", Udp), + entry("gre", Gre), + entry("ipv6-icmp", IcmpIpV6), + entry("icmpv6", IcmpIpV6), + entry("eigrp", Eigrp), + entry("ospf", Ospf), + entry("pim", Pim), + entry("sctp", Sctp) + ); Type(int transportNumber) { this.transportNumber = transportNumber; @@ -443,15 +444,15 @@ public int getTransportNumber() { } } - private Type type; - private int transportNumber; + private final Type type; + private final int transportNumber; - Transport(int transportNumber, Type type) { // Change constructor to public + private Transport(int transportNumber, Type type) { this.transportNumber = transportNumber; this.type = type; } - Transport(Type type) { // Change constructor to public + private Transport(Type type) { this.transportNumber = type.getTransportNumber(); this.type = type; } @@ -464,7 +465,8 @@ public int getTransportNumber() { return transportNumber; } - public static Transport fromNumber(int transportNumber) { + // visible for testing + static Transport fromNumber(int transportNumber) { if (transportNumber < 0 || transportNumber >= 255) { // transport numbers range https://www.iana.org/assignments/protocol-numbers/protocol-numbers.xhtml throw new IllegalArgumentException("invalid transport protocol number [" + transportNumber + "]"); @@ -487,7 +489,7 @@ public static Transport fromNumber(int transportNumber) { return new Transport(transportNumber, type); } - public static Transport fromObject(Object o) { + private static Transport fromObject(Object o) { if (o instanceof Number number) { return fromNumber(number.intValue()); } else if (o instanceof String protocolStr) { @@ -537,34 +539,31 @@ public enum IcmpType { V6HomeAddressDiscoveryRequest(144), V6HomeAddressDiscoveryResponse(145); - private static final Map ICMP_V4_CODE_EQUIVALENTS; - private static final Map ICMP_V6_CODE_EQUIVALENTS; - - static { - ICMP_V4_CODE_EQUIVALENTS = 
new HashMap<>(); - ICMP_V4_CODE_EQUIVALENTS.put(EchoRequest.getType(), EchoReply.getType()); - ICMP_V4_CODE_EQUIVALENTS.put(EchoReply.getType(), EchoRequest.getType()); - ICMP_V4_CODE_EQUIVALENTS.put(TimestampRequest.getType(), TimestampReply.getType()); - ICMP_V4_CODE_EQUIVALENTS.put(TimestampReply.getType(), TimestampRequest.getType()); - ICMP_V4_CODE_EQUIVALENTS.put(InfoRequest.getType(), InfoReply.getType()); - ICMP_V4_CODE_EQUIVALENTS.put(RouterSolicitation.getType(), RouterAdvertisement.getType()); - ICMP_V4_CODE_EQUIVALENTS.put(RouterAdvertisement.getType(), RouterSolicitation.getType()); - ICMP_V4_CODE_EQUIVALENTS.put(AddressMaskRequest.getType(), AddressMaskReply.getType()); - ICMP_V4_CODE_EQUIVALENTS.put(AddressMaskReply.getType(), AddressMaskRequest.getType()); - - ICMP_V6_CODE_EQUIVALENTS = new HashMap<>(); - ICMP_V6_CODE_EQUIVALENTS.put(V6EchoRequest.getType(), V6EchoReply.getType()); - ICMP_V6_CODE_EQUIVALENTS.put(V6EchoReply.getType(), V6EchoRequest.getType()); - ICMP_V6_CODE_EQUIVALENTS.put(V6RouterSolicitation.getType(), V6RouterAdvertisement.getType()); - ICMP_V6_CODE_EQUIVALENTS.put(V6RouterAdvertisement.getType(), V6RouterSolicitation.getType()); - ICMP_V6_CODE_EQUIVALENTS.put(V6NeighborAdvertisement.getType(), V6NeighborSolicitation.getType()); - ICMP_V6_CODE_EQUIVALENTS.put(V6NeighborSolicitation.getType(), V6NeighborAdvertisement.getType()); - ICMP_V6_CODE_EQUIVALENTS.put(V6MLDv1MulticastListenerQueryMessage.getType(), V6MLDv1MulticastListenerReportMessage.getType()); - ICMP_V6_CODE_EQUIVALENTS.put(V6WhoAreYouRequest.getType(), V6WhoAreYouReply.getType()); - ICMP_V6_CODE_EQUIVALENTS.put(V6WhoAreYouReply.getType(), V6WhoAreYouRequest.getType()); - ICMP_V6_CODE_EQUIVALENTS.put(V6HomeAddressDiscoveryRequest.getType(), V6HomeAddressDiscoveryResponse.getType()); - ICMP_V6_CODE_EQUIVALENTS.put(V6HomeAddressDiscoveryResponse.getType(), V6HomeAddressDiscoveryRequest.getType()); - } + private static final Map ICMP_V4_CODE_EQUIVALENTS = Map.ofEntries( + entry(EchoRequest.getType(), EchoReply.getType()), + entry(EchoReply.getType(), EchoRequest.getType()), + entry(TimestampRequest.getType(), TimestampReply.getType()), + entry(TimestampReply.getType(), TimestampRequest.getType()), + entry(InfoRequest.getType(), InfoReply.getType()), + entry(RouterSolicitation.getType(), RouterAdvertisement.getType()), + entry(RouterAdvertisement.getType(), RouterSolicitation.getType()), + entry(AddressMaskRequest.getType(), AddressMaskReply.getType()), + entry(AddressMaskReply.getType(), AddressMaskRequest.getType()) + ); + + private static final Map ICMP_V6_CODE_EQUIVALENTS = Map.ofEntries( + entry(V6EchoRequest.getType(), V6EchoReply.getType()), + entry(V6EchoReply.getType(), V6EchoRequest.getType()), + entry(V6RouterSolicitation.getType(), V6RouterAdvertisement.getType()), + entry(V6RouterAdvertisement.getType(), V6RouterSolicitation.getType()), + entry(V6NeighborAdvertisement.getType(), V6NeighborSolicitation.getType()), + entry(V6NeighborSolicitation.getType(), V6NeighborAdvertisement.getType()), + entry(V6MLDv1MulticastListenerQueryMessage.getType(), V6MLDv1MulticastListenerReportMessage.getType()), + entry(V6WhoAreYouRequest.getType(), V6WhoAreYouReply.getType()), + entry(V6WhoAreYouReply.getType(), V6WhoAreYouRequest.getType()), + entry(V6HomeAddressDiscoveryRequest.getType(), V6HomeAddressDiscoveryResponse.getType()), + entry(V6HomeAddressDiscoveryResponse.getType(), V6HomeAddressDiscoveryRequest.getType()) + ); private final int type; @@ -606,7 +605,7 @@ public static IcmpType 
fromNumber(int type) { }; } - public static Integer codeEquivalent(int icmpType, boolean isIpV6) { + private static Integer codeEquivalent(int icmpType, boolean isIpV6) { return isIpV6 ? ICMP_V6_CODE_EQUIVALENTS.get(icmpType) : ICMP_V4_CODE_EQUIVALENTS.get(icmpType); } } diff --git a/modules/ingest-common/src/test/java/org/elasticsearch/ingest/common/CommunityIdProcessorFactoryTests.java b/modules/ingest-common/src/test/java/org/elasticsearch/ingest/common/CommunityIdProcessorFactoryTests.java index 3426e9894c2b1..2f47d1cf5b6e8 100644 --- a/modules/ingest-common/src/test/java/org/elasticsearch/ingest/common/CommunityIdProcessorFactoryTests.java +++ b/modules/ingest-common/src/test/java/org/elasticsearch/ingest/common/CommunityIdProcessorFactoryTests.java @@ -64,61 +64,60 @@ public void testCreate() throws Exception { boolean ignoreMissing = randomBoolean(); config.put("ignore_missing", ignoreMissing); - String processorTag = randomAlphaOfLength(10); - CommunityIdProcessor communityIdProcessor = factory.create(null, processorTag, null, config, null); - assertThat(communityIdProcessor.getTag(), equalTo(processorTag)); - assertThat(communityIdProcessor.getSourceIpField(), equalTo(sourceIpField)); - assertThat(communityIdProcessor.getSourcePortField(), equalTo(sourcePortField)); - assertThat(communityIdProcessor.getDestinationIpField(), equalTo(destIpField)); - assertThat(communityIdProcessor.getDestinationPortField(), equalTo(destPortField)); - assertThat(communityIdProcessor.getIanaNumberField(), equalTo(ianaNumberField)); - assertThat(communityIdProcessor.getTransportField(), equalTo(transportField)); - assertThat(communityIdProcessor.getIcmpTypeField(), equalTo(icmpTypeField)); - assertThat(communityIdProcessor.getIcmpCodeField(), equalTo(icmpCodeField)); - assertThat(communityIdProcessor.getTargetField(), equalTo(targetField)); - assertThat(communityIdProcessor.getSeed(), equalTo(toUint16(seedInt))); - assertThat(communityIdProcessor.getIgnoreMissing(), equalTo(ignoreMissing)); + String tag = randomAlphaOfLength(10); + CommunityIdProcessor processor = factory.create(null, tag, null, config, null); + assertThat(processor.getTag(), equalTo(tag)); + assertThat(processor.getSourceIpField(), equalTo(sourceIpField)); + assertThat(processor.getSourcePortField(), equalTo(sourcePortField)); + assertThat(processor.getDestinationIpField(), equalTo(destIpField)); + assertThat(processor.getDestinationPortField(), equalTo(destPortField)); + assertThat(processor.getIanaNumberField(), equalTo(ianaNumberField)); + assertThat(processor.getTransportField(), equalTo(transportField)); + assertThat(processor.getIcmpTypeField(), equalTo(icmpTypeField)); + assertThat(processor.getIcmpCodeField(), equalTo(icmpCodeField)); + assertThat(processor.getTargetField(), equalTo(targetField)); + assertThat(processor.getSeed(), equalTo(toUint16(seedInt))); + assertThat(processor.getIgnoreMissing(), equalTo(ignoreMissing)); } public void testSeed() throws Exception { Map config = new HashMap<>(); - String processorTag = randomAlphaOfLength(10); + String tag = randomAlphaOfLength(10); // negative seeds are rejected int tooSmallSeed = randomIntBetween(Integer.MIN_VALUE, -1); config.put("seed", Integer.toString(tooSmallSeed)); - ElasticsearchException e = expectThrows(ElasticsearchException.class, () -> factory.create(null, processorTag, null, config, null)); + ElasticsearchException e = expectThrows(ElasticsearchException.class, () -> factory.create(null, tag, null, config, null)); assertThat(e.getMessage(), 
containsString("must be a value between 0 and 65535")); // seeds >= 2^16 are rejected int tooBigSeed = randomIntBetween(65536, Integer.MAX_VALUE); config.put("seed", Integer.toString(tooBigSeed)); - e = expectThrows(ElasticsearchException.class, () -> factory.create(null, processorTag, null, config, null)); + e = expectThrows(ElasticsearchException.class, () -> factory.create(null, tag, null, config, null)); assertThat(e.getMessage(), containsString("must be a value between 0 and 65535")); // seeds between 0 and 2^16-1 are accepted int justRightSeed = randomIntBetween(0, 65535); byte[] expectedSeed = new byte[] { (byte) (justRightSeed >> 8), (byte) justRightSeed }; config.put("seed", Integer.toString(justRightSeed)); - CommunityIdProcessor communityIdProcessor = factory.create(null, processorTag, null, config, null); - assertThat(communityIdProcessor.getSeed(), equalTo(expectedSeed)); + CommunityIdProcessor processor = factory.create(null, tag, null, config, null); + assertThat(processor.getSeed(), equalTo(expectedSeed)); } public void testRequiredFields() throws Exception { - HashMap config = new HashMap<>(); - String processorTag = randomAlphaOfLength(10); - CommunityIdProcessor communityIdProcessor = factory.create(null, processorTag, null, config, null); - assertThat(communityIdProcessor.getTag(), equalTo(processorTag)); - assertThat(communityIdProcessor.getSourceIpField(), equalTo(DEFAULT_SOURCE_IP)); - assertThat(communityIdProcessor.getSourcePortField(), equalTo(DEFAULT_SOURCE_PORT)); - assertThat(communityIdProcessor.getDestinationIpField(), equalTo(DEFAULT_DEST_IP)); - assertThat(communityIdProcessor.getDestinationPortField(), equalTo(DEFAULT_DEST_PORT)); - assertThat(communityIdProcessor.getIanaNumberField(), equalTo(DEFAULT_IANA_NUMBER)); - assertThat(communityIdProcessor.getTransportField(), equalTo(DEFAULT_TRANSPORT)); - assertThat(communityIdProcessor.getIcmpTypeField(), equalTo(DEFAULT_ICMP_TYPE)); - assertThat(communityIdProcessor.getIcmpCodeField(), equalTo(DEFAULT_ICMP_CODE)); - assertThat(communityIdProcessor.getTargetField(), equalTo(DEFAULT_TARGET)); - assertThat(communityIdProcessor.getSeed(), equalTo(toUint16(0))); - assertThat(communityIdProcessor.getIgnoreMissing(), equalTo(true)); + String tag = randomAlphaOfLength(10); + CommunityIdProcessor processor = factory.create(null, tag, null, new HashMap<>(), null); + assertThat(processor.getTag(), equalTo(tag)); + assertThat(processor.getSourceIpField(), equalTo(DEFAULT_SOURCE_IP)); + assertThat(processor.getSourcePortField(), equalTo(DEFAULT_SOURCE_PORT)); + assertThat(processor.getDestinationIpField(), equalTo(DEFAULT_DEST_IP)); + assertThat(processor.getDestinationPortField(), equalTo(DEFAULT_DEST_PORT)); + assertThat(processor.getIanaNumberField(), equalTo(DEFAULT_IANA_NUMBER)); + assertThat(processor.getTransportField(), equalTo(DEFAULT_TRANSPORT)); + assertThat(processor.getIcmpTypeField(), equalTo(DEFAULT_ICMP_TYPE)); + assertThat(processor.getIcmpCodeField(), equalTo(DEFAULT_ICMP_CODE)); + assertThat(processor.getTargetField(), equalTo(DEFAULT_TARGET)); + assertThat(processor.getSeed(), equalTo(toUint16(0))); + assertThat(processor.getIgnoreMissing(), equalTo(true)); } } diff --git a/modules/ingest-common/src/test/java/org/elasticsearch/ingest/common/CommunityIdProcessorTests.java b/modules/ingest-common/src/test/java/org/elasticsearch/ingest/common/CommunityIdProcessorTests.java index dff9916093586..72771275743b0 100644 --- 
a/modules/ingest-common/src/test/java/org/elasticsearch/ingest/common/CommunityIdProcessorTests.java +++ b/modules/ingest-common/src/test/java/org/elasticsearch/ingest/common/CommunityIdProcessorTests.java @@ -9,6 +9,7 @@ package org.elasticsearch.ingest.common; +import org.elasticsearch.ExceptionsHelper; import org.elasticsearch.ingest.IngestDocument; import org.elasticsearch.ingest.TestIngestDocument; import org.elasticsearch.test.ESTestCase; @@ -32,7 +33,7 @@ public class CommunityIdProcessorTests extends ESTestCase { // NOTE: all test methods beginning with "testBeats" are intended to duplicate the unit tests for the Beats - // community_id processor (see Github link below) to ensure that this processor produces the same values. To + // community_id processor (see GitHub link below) to ensure that this processor produces the same values. To // the extent possible, these tests should be kept in sync. // // https://github.com/elastic/beats/blob/master/libbeat/processors/communityid/communityid_test.go @@ -40,77 +41,80 @@ public class CommunityIdProcessorTests extends ESTestCase { private Map event; @Before - public void setup() throws Exception { + public void setup() { event = buildEvent(); } private Map buildEvent() { - event = new HashMap<>(); var source = new HashMap(); source.put("ip", "128.232.110.120"); source.put("port", 34855); - event.put("source", source); + var destination = new HashMap(); destination.put("ip", "66.35.250.204"); destination.put("port", 80); - event.put("destination", destination); + var network = new HashMap(); network.put("transport", "TCP"); + + var event = new HashMap(); + event.put("source", source); + event.put("destination", destination); event.put("network", network); return event; } - public void testBeatsValid() throws Exception { - testCommunityIdProcessor(event, "1:LQU9qZlK+B5F3KDmev6m5PMibrg="); + public void testBeatsValid() { + testProcessor(event, "1:LQU9qZlK+B5F3KDmev6m5PMibrg="); } - public void testBeatsSeed() throws Exception { - testCommunityIdProcessor(event, 123, "1:hTSGlFQnR58UCk+NfKRZzA32dPg="); + public void testBeatsSeed() { + testProcessor(event, 123, "1:hTSGlFQnR58UCk+NfKRZzA32dPg="); } - public void testBeatsInvalidSourceIp() throws Exception { + public void testBeatsInvalidSourceIp() { @SuppressWarnings("unchecked") var source = (Map) event.get("source"); source.put("ip", 2162716280L); - IllegalArgumentException e = expectThrows(IllegalArgumentException.class, () -> testCommunityIdProcessor(event, null)); + IllegalArgumentException e = expectThrows(IllegalArgumentException.class, () -> testProcessor(event, null)); assertThat(e.getMessage(), containsString("field [source.ip] of type [java.lang.Long] cannot be cast to [java.lang.String]")); } - public void testBeatsInvalidSourcePort() throws Exception { + public void testBeatsInvalidSourcePort() { @SuppressWarnings("unchecked") var source = (Map) event.get("source"); source.put("port", 0); - IllegalArgumentException e = expectThrows(IllegalArgumentException.class, () -> testCommunityIdProcessor(event, null)); + IllegalArgumentException e = expectThrows(IllegalArgumentException.class, () -> testProcessor(event, null)); assertThat(e.getMessage(), containsString("invalid source port")); } - public void testBeatsInvalidDestinationIp() throws Exception { + public void testBeatsInvalidDestinationIp() { @SuppressWarnings("unchecked") var destination = (Map) event.get("destination"); String invalidIp = "308.111.1.2.3"; destination.put("ip", invalidIp); - IllegalArgumentException e = 
expectThrows(IllegalArgumentException.class, () -> testCommunityIdProcessor(event, null)); + IllegalArgumentException e = expectThrows(IllegalArgumentException.class, () -> testProcessor(event, null)); assertThat(e.getMessage(), containsString("'" + invalidIp + "' is not an IP string literal")); } - public void testBeatsInvalidDestinationPort() throws Exception { + public void testBeatsInvalidDestinationPort() { @SuppressWarnings("unchecked") var destination = (Map) event.get("destination"); destination.put("port", null); - IllegalArgumentException e = expectThrows(IllegalArgumentException.class, () -> testCommunityIdProcessor(event, null)); + IllegalArgumentException e = expectThrows(IllegalArgumentException.class, () -> testProcessor(event, null)); // slightly modified from the beats test in that this one reports the actual invalid value rather than '0' assertThat(e.getMessage(), containsString("invalid destination port [null]")); } - public void testBeatsUnknownProtocol() throws Exception { + public void testBeatsUnknownProtocol() { @SuppressWarnings("unchecked") var network = (Map) event.get("network"); network.put("transport", "xyz"); - IllegalArgumentException e = expectThrows(IllegalArgumentException.class, () -> testCommunityIdProcessor(event, null)); + IllegalArgumentException e = expectThrows(IllegalArgumentException.class, () -> testProcessor(event, null)); assertThat(e.getMessage(), containsString("could not convert string [xyz] to transport protocol")); } - public void testBeatsIcmp() throws Exception { + public void testBeatsIcmp() { @SuppressWarnings("unchecked") var network = (Map) event.get("network"); network.put("transport", "icmp"); @@ -118,17 +122,17 @@ public void testBeatsIcmp() throws Exception { icmp.put("type", 3); icmp.put("code", 3); event.put("icmp", icmp); - testCommunityIdProcessor(event, "1:KF3iG9XD24nhlSy4r1TcYIr5mfE="); + testProcessor(event, "1:KF3iG9XD24nhlSy4r1TcYIr5mfE="); } - public void testBeatsIcmpWithoutTypeOrCode() throws Exception { + public void testBeatsIcmpWithoutTypeOrCode() { @SuppressWarnings("unchecked") var network = (Map) event.get("network"); network.put("transport", "icmp"); - testCommunityIdProcessor(event, "1:PAE85ZfR4SbNXl5URZwWYyDehwU="); + testProcessor(event, "1:PAE85ZfR4SbNXl5URZwWYyDehwU="); } - public void testBeatsIgmp() throws Exception { + public void testBeatsIgmp() { @SuppressWarnings("unchecked") var network = (Map) event.get("network"); network.put("transport", "igmp"); @@ -138,10 +142,10 @@ public void testBeatsIgmp() throws Exception { @SuppressWarnings("unchecked") var destination = (Map) event.get("destination"); destination.remove("port"); - testCommunityIdProcessor(event, "1:D3t8Q1aFA6Ev0A/AO4i9PnU3AeI="); + testProcessor(event, "1:D3t8Q1aFA6Ev0A/AO4i9PnU3AeI="); } - public void testBeatsProtocolNumberAsString() throws Exception { + public void testBeatsProtocolNumberAsString() { @SuppressWarnings("unchecked") var source = (Map) event.get("source"); source.remove("port"); @@ -151,10 +155,10 @@ public void testBeatsProtocolNumberAsString() throws Exception { @SuppressWarnings("unchecked") var network = (Map) event.get("network"); network.put("transport", "2"); - testCommunityIdProcessor(event, "1:D3t8Q1aFA6Ev0A/AO4i9PnU3AeI="); + testProcessor(event, "1:D3t8Q1aFA6Ev0A/AO4i9PnU3AeI="); } - public void testBeatsProtocolNumber() throws Exception { + public void testBeatsProtocolNumber() { @SuppressWarnings("unchecked") var source = (Map) event.get("source"); source.remove("port"); @@ -164,18 +168,18 @@ public void 
testBeatsProtocolNumber() throws Exception { @SuppressWarnings("unchecked") var network = (Map) event.get("network"); network.put("transport", 2); - testCommunityIdProcessor(event, "1:D3t8Q1aFA6Ev0A/AO4i9PnU3AeI="); + testProcessor(event, "1:D3t8Q1aFA6Ev0A/AO4i9PnU3AeI="); } - public void testBeatsIanaNumberProtocolTCP() throws Exception { + public void testBeatsIanaNumberProtocolTCP() { @SuppressWarnings("unchecked") var network = (Map) event.get("network"); network.remove("transport"); network.put("iana_number", CommunityIdProcessor.Transport.Type.Tcp.getTransportNumber()); - testCommunityIdProcessor(event, "1:LQU9qZlK+B5F3KDmev6m5PMibrg="); + testProcessor(event, "1:LQU9qZlK+B5F3KDmev6m5PMibrg="); } - public void testBeatsIanaNumberProtocolIPv4() throws Exception { + public void testBeatsIanaNumberProtocolIPv4() { @SuppressWarnings("unchecked") var network = (Map) event.get("network"); network.put("iana_number", "4"); @@ -188,20 +192,20 @@ public void testBeatsIanaNumberProtocolIPv4() throws Exception { var destination = (Map) event.get("destination"); destination.put("ip", "10.1.2.3"); destination.remove("port"); - testCommunityIdProcessor(event, "1:KXQzmk3bdsvD6UXj7dvQ4bM6Zvw="); + testProcessor(event, "1:KXQzmk3bdsvD6UXj7dvQ4bM6Zvw="); } - public void testIpv6() throws Exception { + public void testIpv6() { @SuppressWarnings("unchecked") var source = (Map) event.get("source"); source.put("ip", "2001:0db8:85a3:0000:0000:8a2e:0370:7334"); @SuppressWarnings("unchecked") var destination = (Map) event.get("destination"); destination.put("ip", "2001:0:9d38:6ab8:1c48:3a1c:a95a:b1c2"); - testCommunityIdProcessor(event, "1:YC1+javPJ2LpK5xVyw1udfT83Qs="); + testProcessor(event, "1:YC1+javPJ2LpK5xVyw1udfT83Qs="); } - public void testIcmpWithCodeEquivalent() throws Exception { + public void testIcmpWithCodeEquivalent() { @SuppressWarnings("unchecked") var network = (Map) event.get("network"); network.put("transport", "icmp"); @@ -209,20 +213,20 @@ public void testIcmpWithCodeEquivalent() throws Exception { icmp.put("type", 10); icmp.put("code", 3); event.put("icmp", icmp); - testCommunityIdProcessor(event, "1:L8wnzpmRHIESLqLBy+zTqW3Pmqs="); + testProcessor(event, "1:L8wnzpmRHIESLqLBy+zTqW3Pmqs="); } - public void testStringAndNumber() throws Exception { + public void testStringAndNumber() { // iana event = buildEvent(); @SuppressWarnings("unchecked") var network = (Map) event.get("network"); network.remove("transport"); network.put("iana_number", CommunityIdProcessor.Transport.Type.Tcp.getTransportNumber()); - testCommunityIdProcessor(event, "1:LQU9qZlK+B5F3KDmev6m5PMibrg="); + testProcessor(event, "1:LQU9qZlK+B5F3KDmev6m5PMibrg="); network.put("iana_number", Integer.toString(CommunityIdProcessor.Transport.Type.Tcp.getTransportNumber())); - testCommunityIdProcessor(event, "1:LQU9qZlK+B5F3KDmev6m5PMibrg="); + testProcessor(event, "1:LQU9qZlK+B5F3KDmev6m5PMibrg="); // protocol number event = buildEvent(); @@ -235,30 +239,30 @@ public void testStringAndNumber() throws Exception { @SuppressWarnings("unchecked") var network2 = (Map) event.get("network"); network2.put("transport", 2); - testCommunityIdProcessor(event, "1:D3t8Q1aFA6Ev0A/AO4i9PnU3AeI="); + testProcessor(event, "1:D3t8Q1aFA6Ev0A/AO4i9PnU3AeI="); network2.put("transport", "2"); - testCommunityIdProcessor(event, "1:D3t8Q1aFA6Ev0A/AO4i9PnU3AeI="); + testProcessor(event, "1:D3t8Q1aFA6Ev0A/AO4i9PnU3AeI="); // source port event = buildEvent(); @SuppressWarnings("unchecked") var source2 = (Map) event.get("source"); source2.put("port", 34855); - 
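The variants exercised here ("TCP" by name, 2 as a number, "2" as a numeric string, plus a separate iana_number field) all normalize to a single IANA protocol number before hashing. A rough sketch of that normalization, as a hypothetical helper rather than the processor's actual code, with only a few well-known protocols listed:

import java.util.Locale;
import java.util.Map;

final class TransportSketch {

    private static final Map<String, Integer> BY_NAME = Map.of(
        "icmp", 1, "igmp", 2, "tcp", 6, "udp", 17, "ipv6-icmp", 58, "sctp", 132
    );

    // Accepts a transport name ("TCP"), a number (6), or a numeric string ("6"),
    // mirroring how these tests feed network.transport / network.iana_number.
    static int protocolNumber(Object transport) {
        if (transport instanceof Number n) {
            return n.intValue();
        }
        String s = transport.toString();
        Integer byName = BY_NAME.get(s.toLowerCase(Locale.ROOT));
        if (byName != null) {
            return byName;
        }
        try {
            return Integer.parseInt(s);
        } catch (NumberFormatException e) {
            throw new IllegalArgumentException("could not convert string [" + s + "] to transport protocol");
        }
    }
}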
testCommunityIdProcessor(event, "1:LQU9qZlK+B5F3KDmev6m5PMibrg="); + testProcessor(event, "1:LQU9qZlK+B5F3KDmev6m5PMibrg="); source2.put("port", "34855"); - testCommunityIdProcessor(event, "1:LQU9qZlK+B5F3KDmev6m5PMibrg="); + testProcessor(event, "1:LQU9qZlK+B5F3KDmev6m5PMibrg="); // dest port event = buildEvent(); @SuppressWarnings("unchecked") var dest2 = (Map) event.get("destination"); dest2.put("port", 80); - testCommunityIdProcessor(event, "1:LQU9qZlK+B5F3KDmev6m5PMibrg="); + testProcessor(event, "1:LQU9qZlK+B5F3KDmev6m5PMibrg="); dest2.put("port", "80"); - testCommunityIdProcessor(event, "1:LQU9qZlK+B5F3KDmev6m5PMibrg="); + testProcessor(event, "1:LQU9qZlK+B5F3KDmev6m5PMibrg="); // icmp type and code event = buildEvent(); @@ -269,89 +273,87 @@ public void testStringAndNumber() throws Exception { icmp.put("type", 3); icmp.put("code", 3); event.put("icmp", icmp); - testCommunityIdProcessor(event, "1:KF3iG9XD24nhlSy4r1TcYIr5mfE="); + testProcessor(event, "1:KF3iG9XD24nhlSy4r1TcYIr5mfE="); - icmp = new HashMap(); + icmp = new HashMap<>(); icmp.put("type", "3"); icmp.put("code", "3"); event.put("icmp", icmp); - testCommunityIdProcessor(event, "1:KF3iG9XD24nhlSy4r1TcYIr5mfE="); + testProcessor(event, "1:KF3iG9XD24nhlSy4r1TcYIr5mfE="); } - public void testLongsForNumericValues() throws Exception { + public void testLongsForNumericValues() { event = buildEvent(); @SuppressWarnings("unchecked") var source2 = (Map) event.get("source"); source2.put("port", 34855L); - testCommunityIdProcessor(event, "1:LQU9qZlK+B5F3KDmev6m5PMibrg="); + testProcessor(event, "1:LQU9qZlK+B5F3KDmev6m5PMibrg="); } - public void testFloatsForNumericValues() throws Exception { + public void testFloatsForNumericValues() { event = buildEvent(); @SuppressWarnings("unchecked") var source2 = (Map) event.get("source"); source2.put("port", 34855.0); - testCommunityIdProcessor(event, "1:LQU9qZlK+B5F3KDmev6m5PMibrg="); + testProcessor(event, "1:LQU9qZlK+B5F3KDmev6m5PMibrg="); } - public void testInvalidPort() throws Exception { + public void testInvalidPort() { event = buildEvent(); @SuppressWarnings("unchecked") var source = (Map) event.get("source"); source.put("port", 0); - IllegalArgumentException e = expectThrows(IllegalArgumentException.class, () -> testCommunityIdProcessor(event, null)); + IllegalArgumentException e = expectThrows(IllegalArgumentException.class, () -> testProcessor(event, null)); assertThat(e.getMessage(), containsString("invalid source port [0]")); event = buildEvent(); @SuppressWarnings("unchecked") var source2 = (Map) event.get("source"); source2.put("port", 65536); - e = expectThrows(IllegalArgumentException.class, () -> testCommunityIdProcessor(event, null)); + e = expectThrows(IllegalArgumentException.class, () -> testProcessor(event, null)); assertThat(e.getMessage(), containsString("invalid source port [65536]")); event = buildEvent(); @SuppressWarnings("unchecked") var source3 = (Map) event.get("destination"); source3.put("port", 0); - e = expectThrows(IllegalArgumentException.class, () -> testCommunityIdProcessor(event, null)); + e = expectThrows(IllegalArgumentException.class, () -> testProcessor(event, null)); assertThat(e.getMessage(), containsString("invalid destination port [0]")); event = buildEvent(); @SuppressWarnings("unchecked") var source4 = (Map) event.get("destination"); source4.put("port", 65536); - e = expectThrows(IllegalArgumentException.class, () -> testCommunityIdProcessor(event, null)); + e = expectThrows(IllegalArgumentException.class, () -> testProcessor(event, null)); 
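Every failure case above lands in one range guard: a port is hashed as an unsigned 16-bit value, so 0 and anything past 65535 are rejected, while numeric strings, longs, and floats are accepted. A sketch of that check, with an illustrative helper name and the error text the assertions match on:

final class PortSketch {

    // Valid transport-layer ports are 1..65535; the message embeds the
    // offending value, which is what the assertions above match on
    // (including "invalid destination port [null]").
    static int validatePort(Object value, String direction) {
        if (value instanceof String s) {
            return validatePort(Integer.parseInt(s), direction);
        }
        if (value instanceof Number n) {
            int port = n.intValue();
            if (port > 0 && port <= 65535) {
                return port;
            }
        }
        throw new IllegalArgumentException("invalid " + direction + " port [" + value + "]");
    }
}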
assertThat(e.getMessage(), containsString("invalid destination port [65536]")); } - public void testIgnoreMissing() throws Exception { + public void testIgnoreMissing() { @SuppressWarnings("unchecked") var network = (Map) event.get("network"); network.remove("transport"); - testCommunityIdProcessor(event, 0, null, true); + testProcessor(event, 0, null, true); } - public void testIgnoreMissingIsFalse() throws Exception { + public void testIgnoreMissingIsFalse() { @SuppressWarnings("unchecked") var source = (Map) event.get("source"); source.remove("ip"); - IllegalArgumentException e = expectThrows(IllegalArgumentException.class, () -> testCommunityIdProcessor(event, 0, null, false)); + IllegalArgumentException e = expectThrows(IllegalArgumentException.class, () -> testProcessor(event, 0, null, false)); assertThat(e.getMessage(), containsString("field [ip] not present as part of path [source.ip]")); } - private void testCommunityIdProcessor(Map source, String expectedHash) throws Exception { - testCommunityIdProcessor(source, 0, expectedHash); + private static void testProcessor(Map source, String expectedHash) { + testProcessor(source, 0, expectedHash); } - private void testCommunityIdProcessor(Map source, int seed, String expectedHash) throws Exception { - testCommunityIdProcessor(source, seed, expectedHash, false); + private static void testProcessor(Map source, int seed, String expectedHash) { + testProcessor(source, seed, expectedHash, false); } - private void testCommunityIdProcessor(Map source, int seed, String expectedHash, boolean ignoreMissing) - throws Exception { - + private static void testProcessor(Map source, int seed, String expectedHash, boolean ignoreMissing) { var processor = new CommunityIdProcessor( null, null, @@ -369,7 +371,12 @@ private void testCommunityIdProcessor(Map source, int seed, Stri ); IngestDocument input = TestIngestDocument.withDefaultVersion(source); - IngestDocument output = processor.execute(input); + IngestDocument output; + try { + output = processor.execute(input); + } catch (Exception e) { + throw ExceptionsHelper.convertToRuntime(e); + } String hash = output.getFieldValue(DEFAULT_TARGET, String.class, ignoreMissing); assertThat(hash, equalTo(expectedHash)); diff --git a/server/src/main/java/org/elasticsearch/cluster/metadata/InferenceFieldMetadata.java b/server/src/main/java/org/elasticsearch/cluster/metadata/InferenceFieldMetadata.java index 8917d5a9cbbb5..495403e963e45 100644 --- a/server/src/main/java/org/elasticsearch/cluster/metadata/InferenceFieldMetadata.java +++ b/server/src/main/java/org/elasticsearch/cluster/metadata/InferenceFieldMetadata.java @@ -12,6 +12,7 @@ import org.elasticsearch.TransportVersions; import org.elasticsearch.cluster.Diff; import org.elasticsearch.cluster.SimpleDiffable; +import org.elasticsearch.common.Strings; import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; import org.elasticsearch.xcontent.ToXContentFragment; @@ -22,8 +23,11 @@ import java.util.ArrayList; import java.util.Arrays; import java.util.List; +import java.util.Map; import java.util.Objects; +import static org.elasticsearch.TransportVersions.SEMANTIC_TEXT_CHUNKING_CONFIG; + /** * Contains inference field data for fields. 
* As inference is done in the coordinator node to avoid re-doing it at shard / replica level, the coordinator needs to check for the need @@ -35,21 +39,30 @@ public final class InferenceFieldMetadata implements SimpleDiffable chunkingSettings; - public InferenceFieldMetadata(String name, String inferenceId, String[] sourceFields) { - this(name, inferenceId, inferenceId, sourceFields); + public InferenceFieldMetadata(String name, String inferenceId, String[] sourceFields, Map chunkingSettings) { + this(name, inferenceId, inferenceId, sourceFields, chunkingSettings); } - public InferenceFieldMetadata(String name, String inferenceId, String searchInferenceId, String[] sourceFields) { + public InferenceFieldMetadata( + String name, + String inferenceId, + String searchInferenceId, + String[] sourceFields, + Map chunkingSettings + ) { this.name = Objects.requireNonNull(name); this.inferenceId = Objects.requireNonNull(inferenceId); this.searchInferenceId = Objects.requireNonNull(searchInferenceId); this.sourceFields = Objects.requireNonNull(sourceFields); + this.chunkingSettings = chunkingSettings != null ? Map.copyOf(chunkingSettings) : null; } public InferenceFieldMetadata(StreamInput input) throws IOException { @@ -61,6 +74,11 @@ public InferenceFieldMetadata(StreamInput input) throws IOException { this.searchInferenceId = this.inferenceId; } this.sourceFields = input.readStringArray(); + if (input.getTransportVersion().onOrAfter(SEMANTIC_TEXT_CHUNKING_CONFIG)) { + this.chunkingSettings = input.readGenericMap(); + } else { + this.chunkingSettings = null; + } } @Override @@ -71,6 +89,9 @@ public void writeTo(StreamOutput out) throws IOException { out.writeString(searchInferenceId); } out.writeStringArray(sourceFields); + if (out.getTransportVersion().onOrAfter(SEMANTIC_TEXT_CHUNKING_CONFIG)) { + out.writeGenericMap(chunkingSettings); + } } @Override @@ -81,16 +102,22 @@ public boolean equals(Object o) { return Objects.equals(name, that.name) && Objects.equals(inferenceId, that.inferenceId) && Objects.equals(searchInferenceId, that.searchInferenceId) - && Arrays.equals(sourceFields, that.sourceFields); + && Arrays.equals(sourceFields, that.sourceFields) + && Objects.equals(chunkingSettings, that.chunkingSettings); } @Override public int hashCode() { - int result = Objects.hash(name, inferenceId, searchInferenceId); + int result = Objects.hash(name, inferenceId, searchInferenceId, chunkingSettings); result = 31 * result + Arrays.hashCode(sourceFields); return result; } + @Override + public String toString() { + return Strings.toString(this); + } + public String getName() { return name; } @@ -107,6 +134,10 @@ public String[] getSourceFields() { return sourceFields; } + public Map getChunkingSettings() { + return chunkingSettings; + } + public static Diff readDiffFrom(StreamInput in) throws IOException { return SimpleDiffable.readDiffFrom(InferenceFieldMetadata::new, in); } @@ -119,6 +150,11 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws builder.field(SEARCH_INFERENCE_ID_FIELD, searchInferenceId); } builder.array(SOURCE_FIELDS_FIELD, sourceFields); + if (chunkingSettings != null) { + builder.startObject(CHUNKING_SETTINGS_FIELD); + builder.mapContents(chunkingSettings); + builder.endObject(); + } return builder.endObject(); } @@ -131,6 +167,7 @@ public static InferenceFieldMetadata fromXContent(XContentParser parser) throws String currentFieldName = null; String inferenceId = null; String searchInferenceId = null; + Map chunkingSettings = null; List inputFields 
= new ArrayList<>(); while ((token = parser.nextToken()) != XContentParser.Token.END_OBJECT) { if (token == XContentParser.Token.FIELD_NAME) { @@ -151,6 +188,8 @@ public static InferenceFieldMetadata fromXContent(XContentParser parser) throws } } } + } else if (CHUNKING_SETTINGS_FIELD.equals(currentFieldName)) { + chunkingSettings = parser.map(); } else { parser.skipChildren(); } @@ -159,7 +198,8 @@ public static InferenceFieldMetadata fromXContent(XContentParser parser) throws name, inferenceId, searchInferenceId == null ? inferenceId : searchInferenceId, - inputFields.toArray(String[]::new) + inputFields.toArray(String[]::new), + chunkingSettings ); } } diff --git a/server/src/main/java/org/elasticsearch/inference/ChunkInferenceInput.java b/server/src/main/java/org/elasticsearch/inference/ChunkInferenceInput.java new file mode 100644 index 0000000000000..8e25e0e55f08c --- /dev/null +++ b/server/src/main/java/org/elasticsearch/inference/ChunkInferenceInput.java @@ -0,0 +1,25 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the "Elastic License + * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side + * Public License v 1"; you may not use this file except in compliance with, at + * your election, the "Elastic License 2.0", the "GNU Affero General Public + * License v3.0 only", or the "Server Side Public License, v 1". + */ + +package org.elasticsearch.inference; + +import org.elasticsearch.core.Nullable; + +import java.util.List; + +public record ChunkInferenceInput(String input, @Nullable ChunkingSettings chunkingSettings) { + + public ChunkInferenceInput(String input) { + this(input, null); + } + + public static List inputs(List chunkInferenceInputs) { + return chunkInferenceInputs.stream().map(ChunkInferenceInput::input).toList(); + } +} diff --git a/server/src/main/java/org/elasticsearch/inference/ChunkingSettings.java b/server/src/main/java/org/elasticsearch/inference/ChunkingSettings.java index 2e9072626b0a8..34b3e5a6d58ee 100644 --- a/server/src/main/java/org/elasticsearch/inference/ChunkingSettings.java +++ b/server/src/main/java/org/elasticsearch/inference/ChunkingSettings.java @@ -12,6 +12,10 @@ import org.elasticsearch.common.io.stream.VersionedNamedWriteable; import org.elasticsearch.xcontent.ToXContentObject; +import java.util.Map; + public interface ChunkingSettings extends ToXContentObject, VersionedNamedWriteable { ChunkingStrategy getChunkingStrategy(); + + Map asMap(); } diff --git a/server/src/main/java/org/elasticsearch/inference/InferenceService.java b/server/src/main/java/org/elasticsearch/inference/InferenceService.java index 309db20083ece..f36642ab8d627 100644 --- a/server/src/main/java/org/elasticsearch/inference/InferenceService.java +++ b/server/src/main/java/org/elasticsearch/inference/InferenceService.java @@ -133,18 +133,18 @@ void unifiedCompletionInfer( /** * Chunk long text. 
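The new ChunkInferenceInput record pairs each input string with optional per-request chunking settings, where null means "defer to the model's defaults". A short usage sketch against the record exactly as added above:

import org.elasticsearch.inference.ChunkInferenceInput;

import java.util.List;

final class ChunkInputSketch {

    static List<String> rawInputs() {
        // The one-argument constructor leaves chunkingSettings null (model defaults);
        // a caller with field-level overrides would pass a built ChunkingSettings instead.
        List<ChunkInferenceInput> inputs = List.of(
            new ChunkInferenceInput("a long passage of text"),
            new ChunkInferenceInput("another passage", null)
        );
        return ChunkInferenceInput.inputs(inputs); // flatten back to plain strings where settings are irrelevant
    }
}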
* - * @param model The model - * @param query Inference query, mainly for re-ranking - * @param input Inference input - * @param taskSettings Settings in the request to override the model's defaults - * @param inputType For search, ingest etc - * @param timeout The timeout for the request - * @param listener Chunked Inference result listener + * @param model The model + * @param query Inference query, mainly for re-ranking + * @param input Inference input + * @param taskSettings Settings in the request to override the model's defaults + * @param inputType For search, ingest etc + * @param timeout The timeout for the request + * @param listener Chunked Inference result listener */ void chunkedInfer( Model model, @Nullable String query, - List input, + List input, Map taskSettings, InputType inputType, TimeValue timeout, diff --git a/server/src/test/java/org/elasticsearch/cluster/metadata/IndexMetadataTests.java b/server/src/test/java/org/elasticsearch/cluster/metadata/IndexMetadataTests.java index 3dde53956bb2a..39da67a81bec0 100644 --- a/server/src/test/java/org/elasticsearch/cluster/metadata/IndexMetadataTests.java +++ b/server/src/test/java/org/elasticsearch/cluster/metadata/IndexMetadataTests.java @@ -727,7 +727,8 @@ private static InferenceFieldMetadata randomInferenceFieldMetadata(String name) name, randomIdentifier(), randomIdentifier(), - randomSet(1, 5, ESTestCase::randomIdentifier).toArray(String[]::new) + randomSet(1, 5, ESTestCase::randomIdentifier).toArray(String[]::new), + InferenceFieldMetadataTests.generateRandomChunkingSettings() ); } diff --git a/server/src/test/java/org/elasticsearch/cluster/metadata/InferenceFieldMetadataTests.java b/server/src/test/java/org/elasticsearch/cluster/metadata/InferenceFieldMetadataTests.java index 2d5805696320d..f0c61b68226e1 100644 --- a/server/src/test/java/org/elasticsearch/cluster/metadata/InferenceFieldMetadataTests.java +++ b/server/src/test/java/org/elasticsearch/cluster/metadata/InferenceFieldMetadataTests.java @@ -15,8 +15,10 @@ import org.elasticsearch.xcontent.XContentParser; import java.io.IOException; +import java.util.Map; import java.util.function.Predicate; +import static org.elasticsearch.cluster.metadata.InferenceFieldMetadata.CHUNKING_SETTINGS_FIELD; import static org.hamcrest.Matchers.equalTo; public class InferenceFieldMetadataTests extends AbstractXContentTestCase { @@ -37,11 +39,6 @@ protected InferenceFieldMetadata createTestInstance() { return createTestItem(); } - @Override - protected Predicate getRandomFieldsExcludeFilter() { - return p -> p.equals(""); // do not add elements at the top-level as any element at this level is parsed as a new inference field - } - @Override protected InferenceFieldMetadata doParseInstance(XContentParser parser) throws IOException { if (parser.nextToken() == XContentParser.Token.START_OBJECT) { @@ -58,18 +55,57 @@ protected boolean supportsUnknownFields() { return true; } + @Override + protected Predicate getRandomFieldsExcludeFilter() { + // do not add elements at the top-level as any element at this level is parsed as a new inference field, + // and do not add additional elements to chunking maps as they will fail parsing with extra data + return field -> field.equals("") || field.contains(CHUNKING_SETTINGS_FIELD); + } + private static InferenceFieldMetadata createTestItem() { String name = randomAlphaOfLengthBetween(3, 10); String inferenceId = randomIdentifier(); String searchInferenceId = randomIdentifier(); String[] inputFields = generateRandomStringArray(5, 10, false, false); - 
return new InferenceFieldMetadata(name, inferenceId, searchInferenceId, inputFields); + Map chunkingSettings = generateRandomChunkingSettings(); + return new InferenceFieldMetadata(name, inferenceId, searchInferenceId, inputFields, chunkingSettings); + } + + public static Map generateRandomChunkingSettings() { + if (randomBoolean()) { + return null; // Defaults to model chunking settings + } + return randomBoolean() ? generateRandomWordBoundaryChunkingSettings() : generateRandomSentenceBoundaryChunkingSettings(); + } + + private static Map generateRandomWordBoundaryChunkingSettings() { + return Map.of("strategy", "word_boundary", "max_chunk_size", randomIntBetween(20, 100), "overlap", randomIntBetween(1, 50)); + } + + private static Map generateRandomSentenceBoundaryChunkingSettings() { + return Map.of( + "strategy", + "sentence_boundary", + "max_chunk_size", + randomIntBetween(20, 100), + "sentence_overlap", + randomIntBetween(0, 1) + ); } public void testNullCtorArgsThrowException() { - assertThrows(NullPointerException.class, () -> new InferenceFieldMetadata(null, "inferenceId", "searchInferenceId", new String[0])); - assertThrows(NullPointerException.class, () -> new InferenceFieldMetadata("name", null, "searchInferenceId", new String[0])); - assertThrows(NullPointerException.class, () -> new InferenceFieldMetadata("name", "inferenceId", null, new String[0])); - assertThrows(NullPointerException.class, () -> new InferenceFieldMetadata("name", "inferenceId", "searchInferenceId", null)); + assertThrows( + NullPointerException.class, + () -> new InferenceFieldMetadata(null, "inferenceId", "searchInferenceId", new String[0], Map.of()) + ); + assertThrows( + NullPointerException.class, + () -> new InferenceFieldMetadata("name", null, "searchInferenceId", new String[0], Map.of()) + ); + assertThrows(NullPointerException.class, () -> new InferenceFieldMetadata("name", "inferenceId", null, new String[0], Map.of())); + assertThrows( + NullPointerException.class, + () -> new InferenceFieldMetadata("name", "inferenceId", "searchInferenceId", null, Map.of()) + ); } } diff --git a/server/src/test/java/org/elasticsearch/index/mapper/MappingLookupInferenceFieldMapperTests.java b/server/src/test/java/org/elasticsearch/index/mapper/MappingLookupInferenceFieldMapperTests.java index 755b83e8eb7ad..93ac31c9ba582 100644 --- a/server/src/test/java/org/elasticsearch/index/mapper/MappingLookupInferenceFieldMapperTests.java +++ b/server/src/test/java/org/elasticsearch/index/mapper/MappingLookupInferenceFieldMapperTests.java @@ -11,6 +11,7 @@ import org.apache.lucene.search.Query; import org.elasticsearch.cluster.metadata.InferenceFieldMetadata; +import org.elasticsearch.cluster.metadata.InferenceFieldMetadataTests; import org.elasticsearch.index.query.SearchExecutionContext; import org.elasticsearch.plugins.MapperPlugin; import org.elasticsearch.plugins.Plugin; @@ -102,7 +103,13 @@ private static class TestInferenceFieldMapper extends FieldMapper implements Inf @Override public InferenceFieldMetadata getMetadata(Set sourcePaths) { - return new InferenceFieldMetadata(fullPath(), INFERENCE_ID, SEARCH_INFERENCE_ID, sourcePaths.toArray(new String[0])); + return new InferenceFieldMetadata( + fullPath(), + INFERENCE_ID, + SEARCH_INFERENCE_ID, + sourcePaths.toArray(new String[0]), + InferenceFieldMetadataTests.generateRandomChunkingSettings() + ); } @Override diff --git a/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/AbstractTestInferenceService.java 
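For reference, the maps those two generators emit are the shapes the builder later parses: word_boundary pairs max_chunk_size with overlap, while sentence_boundary pairs it with sentence_overlap (restricted to 0 or 1). Illustrative literals within the generators' ranges, assuming Map<String, Object> where the extraction dropped the generics:

import java.util.Map;

final class ChunkingSettingsMapSketch {

    // word_boundary: overlap is the number of words shared between consecutive chunks.
    static final Map<String, Object> WORD_BOUNDARY = Map.of(
        "strategy", "word_boundary",
        "max_chunk_size", 50,
        "overlap", 10
    );

    // sentence_boundary: sentence_overlap may only be 0 or 1, as the generator above enforces.
    static final Map<String, Object> SENTENCE_BOUNDARY = Map.of(
        "strategy", "sentence_boundary",
        "max_chunk_size", 50,
        "sentence_overlap", 1
    );
}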
b/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/AbstractTestInferenceService.java index 3c29cef47d628..7d4a120668a8b 100644 --- a/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/AbstractTestInferenceService.java +++ b/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/AbstractTestInferenceService.java @@ -13,6 +13,9 @@ import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; import org.elasticsearch.core.TimeValue; +import org.elasticsearch.inference.ChunkInferenceInput; +import org.elasticsearch.inference.ChunkingSettings; +import org.elasticsearch.inference.ChunkingStrategy; import org.elasticsearch.inference.InferenceService; import org.elasticsearch.inference.Model; import org.elasticsearch.inference.ModelConfigurations; @@ -22,14 +25,20 @@ import org.elasticsearch.inference.TaskSettings; import org.elasticsearch.inference.TaskType; import org.elasticsearch.xcontent.XContentBuilder; +import org.elasticsearch.xpack.inference.chunking.WordBoundaryChunker; +import org.elasticsearch.xpack.inference.chunking.WordBoundaryChunkingSettings; import java.io.IOException; +import java.util.ArrayList; import java.util.HashMap; +import java.util.List; import java.util.Map; import java.util.Random; public abstract class AbstractTestInferenceService implements InferenceService { + protected record ChunkedInput(String input, int startOffset, int endOffset) {} + protected static final Random random = new Random( System.getProperty("tests.seed") == null ? System.currentTimeMillis() @@ -105,6 +114,34 @@ public void start(Model model, TimeValue timeout, ActionListener listen @Override public void close() throws IOException {} + protected List chunkInputs(ChunkInferenceInput input) { + ChunkingSettings chunkingSettings = input.chunkingSettings(); + String inputText = input.input(); + if (chunkingSettings == null) { + return List.of(new ChunkedInput(inputText, 0, inputText.length())); + } + + List chunkedInputs = new ArrayList<>(); + if (chunkingSettings.getChunkingStrategy() == ChunkingStrategy.WORD) { + WordBoundaryChunker chunker = new WordBoundaryChunker(); + WordBoundaryChunkingSettings wordBoundaryChunkingSettings = (WordBoundaryChunkingSettings) chunkingSettings; + List offsets = chunker.chunk( + inputText, + wordBoundaryChunkingSettings.maxChunkSize(), + wordBoundaryChunkingSettings.overlap() + ); + for (WordBoundaryChunker.ChunkOffset offset : offsets) { + chunkedInputs.add(new ChunkedInput(inputText.substring(offset.start(), offset.end()), offset.start(), offset.end())); + } + + } else { + // Won't implement till we need it + throw new UnsupportedOperationException("Test inference service only supports word chunking strategies"); + } + + return chunkedInputs; + } + public static class TestServiceModel extends Model { public TestServiceModel( diff --git a/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestDenseInferenceServiceExtension.java b/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestDenseInferenceServiceExtension.java index ad6f1b88de328..044af0ab1d37d 100644 --- a/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestDenseInferenceServiceExtension.java +++ 
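chunkInputs keeps (substring, start, end) triples so per-chunk results can point back into the original text via TextOffset. A simplified, dependency-free version of the same offset bookkeeping, using naive whitespace windows with no overlap instead of the real WordBoundaryChunker:

import java.util.ArrayList;
import java.util.List;

final class OffsetChunkSketch {

    record Chunk(String text, int start, int end) {}

    // Cut `input` into windows of at most `maxWords` words, tracking character
    // offsets into the original string so results can map chunks back to it.
    static List<Chunk> chunk(String input, int maxWords) {
        List<Chunk> chunks = new ArrayList<>();
        int start = -1; // start offset of the current window, -1 while between windows
        int words = 0;
        int i = 0;
        while (i <= input.length()) {
            boolean boundary = i == input.length() || Character.isWhitespace(input.charAt(i));
            if (boundary == false && start == -1) {
                start = i; // first non-space character of the current window
            }
            if (boundary && start != -1) {
                words++;
                if (words == maxWords || i == input.length()) {
                    chunks.add(new Chunk(input.substring(start, i), start, i));
                    start = -1;
                    words = 0;
                }
            }
            i++;
        }
        return chunks;
    }
}

For "ab cd ef" with maxWords = 2 this yields ("ab cd", 0, 5) and ("ef", 6, 8), the same shape the test service hands to its embedding results.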
b/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestDenseInferenceServiceExtension.java @@ -17,6 +17,7 @@ import org.elasticsearch.core.Nullable; import org.elasticsearch.core.TimeValue; import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper; +import org.elasticsearch.inference.ChunkInferenceInput; import org.elasticsearch.inference.ChunkedInference; import org.elasticsearch.inference.InferenceServiceConfiguration; import org.elasticsearch.inference.InferenceServiceExtension; @@ -35,7 +36,6 @@ import org.elasticsearch.xcontent.ToXContentObject; import org.elasticsearch.xcontent.XContentBuilder; import org.elasticsearch.xpack.core.inference.results.ChunkedInferenceEmbedding; -import org.elasticsearch.xpack.core.inference.results.EmbeddingResults; import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResults; import java.io.IOException; @@ -147,7 +147,7 @@ public void unifiedCompletionInfer( public void chunkedInfer( Model model, @Nullable String query, - List input, + List input, Map taskSettings, InputType inputType, TimeValue timeout, @@ -176,21 +176,20 @@ private TextEmbeddingFloatResults makeResults(List input, ServiceSetting return new TextEmbeddingFloatResults(embeddings); } - private List makeChunkedResults(List input, ServiceSettings serviceSettings) { - TextEmbeddingFloatResults nonChunkedResults = makeResults(input, serviceSettings); - + private List makeChunkedResults(List inputs, ServiceSettings serviceSettings) { var results = new ArrayList(); - for (int i = 0; i < input.size(); i++) { - results.add( - new ChunkedInferenceEmbedding( - List.of( - new EmbeddingResults.Chunk( - nonChunkedResults.embeddings().get(i), - new ChunkedInference.TextOffset(0, input.get(i).length()) - ) + for (ChunkInferenceInput input : inputs) { + List chunkedInput = chunkInputs(input); + List chunks = chunkedInput.stream() + .map( + c -> new TextEmbeddingFloatResults.Chunk( + makeResults(List.of(c.input()), serviceSettings).embeddings().get(0), + new ChunkedInference.TextOffset(c.startOffset(), c.endOffset()) ) ) - ); + .toList(); + ChunkedInferenceEmbedding chunkedInferenceEmbedding = new ChunkedInferenceEmbedding(chunks); + results.add(chunkedInferenceEmbedding); } return results; } diff --git a/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestRerankingServiceExtension.java b/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestRerankingServiceExtension.java index d4e3642affddb..989726443ecf4 100644 --- a/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestRerankingServiceExtension.java +++ b/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestRerankingServiceExtension.java @@ -16,6 +16,7 @@ import org.elasticsearch.common.util.LazyInitializable; import org.elasticsearch.core.Nullable; import org.elasticsearch.core.TimeValue; +import org.elasticsearch.inference.ChunkInferenceInput; import org.elasticsearch.inference.ChunkedInference; import org.elasticsearch.inference.InferenceServiceConfiguration; import org.elasticsearch.inference.InferenceServiceExtension; @@ -136,7 +137,7 @@ public void unifiedCompletionInfer( public void chunkedInfer( Model model, @Nullable String query, - List input, + List input, Map taskSettings, InputType inputType, TimeValue timeout, diff --git 
a/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestSparseInferenceServiceExtension.java b/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestSparseInferenceServiceExtension.java index 6f533d83884ea..03c5c6201ce33 100644 --- a/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestSparseInferenceServiceExtension.java +++ b/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestSparseInferenceServiceExtension.java @@ -16,6 +16,7 @@ import org.elasticsearch.common.util.LazyInitializable; import org.elasticsearch.core.Nullable; import org.elasticsearch.core.TimeValue; +import org.elasticsearch.inference.ChunkInferenceInput; import org.elasticsearch.inference.ChunkedInference; import org.elasticsearch.inference.InferenceServiceConfiguration; import org.elasticsearch.inference.InferenceServiceExtension; @@ -33,7 +34,6 @@ import org.elasticsearch.xcontent.ToXContentObject; import org.elasticsearch.xcontent.XContentBuilder; import org.elasticsearch.xpack.core.inference.results.ChunkedInferenceEmbedding; -import org.elasticsearch.xpack.core.inference.results.EmbeddingResults; import org.elasticsearch.xpack.core.inference.results.SparseEmbeddingResults; import org.elasticsearch.xpack.core.ml.search.WeightedToken; @@ -45,6 +45,7 @@ import java.util.Map; public class TestSparseInferenceServiceExtension implements InferenceServiceExtension { + @Override public List getInferenceServiceFactories() { return List.of(TestInferenceService::new); @@ -137,7 +138,7 @@ public void unifiedCompletionInfer( public void chunkedInfer( Model model, @Nullable String query, - List input, + List input, Map taskSettings, InputType inputType, TimeValue timeout, @@ -166,23 +167,20 @@ private SparseEmbeddingResults makeResults(List input) { return new SparseEmbeddingResults(embeddings); } - private List makeChunkedResults(List input) { + private List makeChunkedResults(List inputs) { List results = new ArrayList<>(); - for (int i = 0; i < input.size(); i++) { - var tokens = new ArrayList(); - for (int j = 0; j < 5; j++) { - tokens.add(new WeightedToken("feature_" + j, generateEmbedding(input.get(i), j))); - } - results.add( - new ChunkedInferenceEmbedding( - List.of( - new EmbeddingResults.Chunk( - new SparseEmbeddingResults.Embedding(tokens, false), - new ChunkedInference.TextOffset(0, input.get(i).length()) - ) - ) - ) - ); + for (ChunkInferenceInput chunkInferenceInput : inputs) { + List chunkedInput = chunkInputs(chunkInferenceInput); + List chunks = chunkedInput.stream().map(c -> { + var tokens = new ArrayList(); + for (int i = 0; i < 5; i++) { + tokens.add(new WeightedToken("feature_" + i, generateEmbedding(c.input(), i))); + } + var embeddings = new SparseEmbeddingResults.Embedding(tokens, false); + return new SparseEmbeddingResults.Chunk(embeddings, new ChunkedInference.TextOffset(c.startOffset(), c.endOffset())); + }).toList(); + ChunkedInferenceEmbedding chunkedInferenceEmbedding = new ChunkedInferenceEmbedding(chunks); + results.add(chunkedInferenceEmbedding); } return results; } diff --git a/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestStreamingCompletionServiceExtension.java b/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestStreamingCompletionServiceExtension.java index 
6bcec22bb50b3..2320429f20704 100644 --- a/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestStreamingCompletionServiceExtension.java +++ b/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestStreamingCompletionServiceExtension.java @@ -17,6 +17,7 @@ import org.elasticsearch.common.xcontent.ChunkedToXContentHelper; import org.elasticsearch.core.Nullable; import org.elasticsearch.core.TimeValue; +import org.elasticsearch.inference.ChunkInferenceInput; import org.elasticsearch.inference.ChunkedInference; import org.elasticsearch.inference.InferenceServiceConfiguration; import org.elasticsearch.inference.InferenceServiceExtension; @@ -260,7 +261,7 @@ public Iterator toXContentChunked(ToXContent.Params params public void chunkedInfer( Model model, String query, - List input, + List input, Map taskSettings, InputType inputType, TimeValue timeout, diff --git a/x-pack/plugin/inference/src/main/java/module-info.java b/x-pack/plugin/inference/src/main/java/module-info.java index 78f30e7da0670..d41aa654f59e6 100644 --- a/x-pack/plugin/inference/src/main/java/module-info.java +++ b/x-pack/plugin/inference/src/main/java/module-info.java @@ -42,6 +42,7 @@ exports org.elasticsearch.xpack.inference.services; exports org.elasticsearch.xpack.inference; exports org.elasticsearch.xpack.inference.action.task; + exports org.elasticsearch.xpack.inference.chunking; exports org.elasticsearch.xpack.inference.telemetry; provides org.elasticsearch.features.FeatureSpecification with org.elasticsearch.xpack.inference.InferenceFeatures; diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceFeatures.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceFeatures.java index 6544c86869434..ae5fc602babb9 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceFeatures.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceFeatures.java @@ -15,6 +15,7 @@ import java.util.Set; +import static org.elasticsearch.xpack.inference.mapper.SemanticTextFieldMapper.SEMANTIC_TEXT_SUPPORT_CHUNKING_CONFIG; import static org.elasticsearch.xpack.inference.queries.SemanticKnnVectorQueryRewriteInterceptor.SEMANTIC_KNN_FILTER_FIX; import static org.elasticsearch.xpack.inference.queries.SemanticKnnVectorQueryRewriteInterceptor.SEMANTIC_KNN_VECTOR_QUERY_REWRITE_INTERCEPTION_SUPPORTED; import static org.elasticsearch.xpack.inference.queries.SemanticMatchQueryRewriteInterceptor.SEMANTIC_MATCH_QUERY_REWRITE_INTERCEPTION_SUPPORTED; @@ -55,7 +56,8 @@ public Set getTestFeatures() { TEST_RERANKING_SERVICE_PARSE_TEXT_AS_SCORE, SemanticTextFieldMapper.SEMANTIC_TEXT_BIT_VECTOR_SUPPORT, SemanticTextFieldMapper.SEMANTIC_TEXT_HANDLE_EMPTY_INPUT, - TEST_RULE_RETRIEVER_WITH_INDICES_THAT_DONT_RETURN_RANK_DOCS + TEST_RULE_RETRIEVER_WITH_INDICES_THAT_DONT_RETURN_RANK_DOCS, + SEMANTIC_TEXT_SUPPORT_CHUNKING_CONFIG ); } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilter.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilter.java index 59555bfed4a17..dff611592dfae 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilter.java +++ 
b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilter.java @@ -35,7 +35,9 @@ import org.elasticsearch.core.Releasable; import org.elasticsearch.core.TimeValue; import org.elasticsearch.index.mapper.InferenceMetadataFieldsMapper; +import org.elasticsearch.inference.ChunkInferenceInput; import org.elasticsearch.inference.ChunkedInference; +import org.elasticsearch.inference.ChunkingSettings; import org.elasticsearch.inference.InferenceService; import org.elasticsearch.inference.InferenceServiceRegistry; import org.elasticsearch.inference.InputType; @@ -54,6 +56,7 @@ import org.elasticsearch.xpack.core.XPackField; import org.elasticsearch.xpack.core.inference.results.ChunkedInferenceError; import org.elasticsearch.xpack.inference.InferenceException; +import org.elasticsearch.xpack.inference.chunking.ChunkingSettingsBuilder; import org.elasticsearch.xpack.inference.mapper.SemanticTextField; import org.elasticsearch.xpack.inference.mapper.SemanticTextFieldMapper; import org.elasticsearch.xpack.inference.mapper.SemanticTextUtils; @@ -172,6 +175,7 @@ private record InferenceProvider(InferenceService service, Model model) {} * @param input The input to run inference on. * @param inputOrder The original order of the input. * @param offsetAdjustment The adjustment to apply to the chunk text offsets. + * @param chunkingSettings Additional explicitly specified chunking settings, or null to use model defaults */ private record FieldInferenceRequest( int bulkItemIndex, @@ -179,7 +183,8 @@ private record FieldInferenceRequest( String sourceField, String input, int inputOrder, - int offsetAdjustment + int offsetAdjustment, + ChunkingSettings chunkingSettings ) {} /** @@ -355,7 +360,10 @@ public void onFailure(Exception exc) { modelRegistry.getModelWithSecrets(inferenceId, modelLoadingListener); return; } - final List inputs = requests.stream().map(FieldInferenceRequest::input).collect(Collectors.toList()); + final List inputs = requests.stream() + .map(r -> new ChunkInferenceInput(r.input, r.chunkingSettings)) + .collect(Collectors.toList()); + ActionListener> completionListener = new ActionListener<>() { @Override public void onResponse(List results) { @@ -450,6 +458,7 @@ private long addFieldInferenceRequests(BulkItemRequest item, int itemIndex, Map< for (var entry : fieldInferenceMap.values()) { String field = entry.getName(); String inferenceId = entry.getInferenceId(); + ChunkingSettings chunkingSettings = ChunkingSettingsBuilder.fromMap(entry.getChunkingSettings(), false); if (useLegacyFormat) { var originalFieldValue = XContentMapValues.extractValue(field, docMap); @@ -524,7 +533,9 @@ private long addFieldInferenceRequests(BulkItemRequest item, int itemIndex, Map< new FieldInferenceResponse(field, sourceField, v, order++, 0, null, EMPTY_CHUNKED_INFERENCE) ); } else { - requests.add(new FieldInferenceRequest(itemIndex, field, sourceField, v, order++, offsetAdjustment)); + requests.add( + new FieldInferenceRequest(itemIndex, field, sourceField, v, order++, offsetAdjustment, chunkingSettings) + ); } // When using the inference metadata fields format, all the input values are concatenated so that the @@ -605,6 +616,7 @@ private void applyInferenceResponses(BulkItemRequest item, FieldInferenceRespons new SemanticTextField.InferenceResult( inferenceFieldMetadata.getInferenceId(), model != null ? 
new MinimalServiceSettings(model) : null, + ChunkingSettingsBuilder.fromMap(inferenceFieldMetadata.getChunkingSettings(), false), chunkMap ), indexRequest.getContentType() diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/ChunkingSettingsBuilder.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/ChunkingSettingsBuilder.java index 2ede1684e315b..25553a4c760f0 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/ChunkingSettingsBuilder.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/ChunkingSettingsBuilder.java @@ -10,6 +10,7 @@ import org.elasticsearch.inference.ChunkingSettings; import org.elasticsearch.inference.ChunkingStrategy; +import java.util.HashMap; import java.util.Map; public class ChunkingSettingsBuilder { @@ -18,13 +19,24 @@ public class ChunkingSettingsBuilder { public static final WordBoundaryChunkingSettings OLD_DEFAULT_SETTINGS = new WordBoundaryChunkingSettings(250, 100); public static ChunkingSettings fromMap(Map settings) { - if (settings == null) { - return OLD_DEFAULT_SETTINGS; - } + return fromMap(settings, true); + } - if (settings.isEmpty()) { - return DEFAULT_SETTINGS; + public static ChunkingSettings fromMap(Map settings, boolean returnDefaultValues) { + + if (returnDefaultValues) { + if (settings == null) { + return OLD_DEFAULT_SETTINGS; + } + if (settings.isEmpty()) { + return DEFAULT_SETTINGS; + } + } else { + if (settings == null || settings.isEmpty()) { + return null; + } } + if (settings.containsKey(ChunkingSettingsOptions.STRATEGY.toString()) == false) { throw new IllegalArgumentException("Can't generate Chunker without ChunkingStrategy provided"); } @@ -33,8 +45,8 @@ public static ChunkingSettings fromMap(Map settings) { settings.get(ChunkingSettingsOptions.STRATEGY.toString()).toString() ); return switch (chunkingStrategy) { - case WORD -> WordBoundaryChunkingSettings.fromMap(settings); - case SENTENCE -> SentenceBoundaryChunkingSettings.fromMap(settings); + case WORD -> WordBoundaryChunkingSettings.fromMap(new HashMap<>(settings)); + case SENTENCE -> SentenceBoundaryChunkingSettings.fromMap(new HashMap<>(settings)); }; } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/EmbeddingRequestChunker.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/EmbeddingRequestChunker.java index 636ab00aa09db..2df2f1e62f89a 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/EmbeddingRequestChunker.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/EmbeddingRequestChunker.java @@ -10,8 +10,10 @@ import org.elasticsearch.ElasticsearchStatusException; import org.elasticsearch.action.ActionListener; import org.elasticsearch.common.util.concurrent.AtomicArray; +import org.elasticsearch.inference.ChunkInferenceInput; import org.elasticsearch.inference.ChunkedInference; import org.elasticsearch.inference.ChunkingSettings; +import org.elasticsearch.inference.ChunkingStrategy; import org.elasticsearch.inference.InferenceServiceResults; import org.elasticsearch.rest.RestStatus; import org.elasticsearch.xpack.core.inference.results.ChunkedInferenceEmbedding; @@ -21,6 +23,8 @@ import java.util.ArrayList; import java.util.List; +import java.util.Map; +import java.util.Objects; import java.util.concurrent.atomic.AtomicInteger; import 
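The boolean parameter separates the two call sites: legacy callers still get defaults substituted for missing settings, while the metadata path needs null back so that "no explicit settings" stays distinguishable from "default settings". A sketch of the resulting behaviour, assuming the builder as patched above and Map<String, Object> signatures where the extraction dropped the generics:

import org.elasticsearch.inference.ChunkingSettings;
import org.elasticsearch.xpack.inference.chunking.ChunkingSettingsBuilder;

import java.util.Map;

final class FromMapSketch {

    static void demo() {
        // Legacy behaviour: absent settings fall back to defaults.
        ChunkingSettings legacyNull = ChunkingSettingsBuilder.fromMap(null);      // old 250/100 word defaults
        ChunkingSettings legacyEmpty = ChunkingSettingsBuilder.fromMap(Map.of()); // current default settings

        // Metadata path: null/empty means "defer to the model", so nothing is substituted.
        ChunkingSettings deferred = ChunkingSettingsBuilder.fromMap(null, false); // null

        // Explicit settings parse the same way on both paths.
        ChunkingSettings explicit = ChunkingSettingsBuilder.fromMap(
            Map.of("strategy", "word_boundary", "max_chunk_size", 100, "overlap", 10),
            false
        );
    }
}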
java.util.concurrent.atomic.AtomicReferenceArray; import java.util.function.Supplier; @@ -40,9 +44,9 @@ public class EmbeddingRequestChunker> { // Visible for testing - record Request(int inputIndex, int chunkIndex, ChunkOffset chunk, List inputs) { - String chunkText() { - return inputs.get(inputIndex).substring(chunk.start(), chunk.end()); + record Request(int inputIndex, int chunkIndex, ChunkOffset chunk, List inputs) { + public String chunkText() { + return inputs.get(inputIndex).input().substring(chunk.start(), chunk.end()); } } @@ -54,8 +58,7 @@ public Supplier> inputs() { public record BatchRequestAndListener(BatchRequest batch, ActionListener listener) {} - private static final int DEFAULT_WORDS_PER_CHUNK = 250; - private static final int DEFAULT_CHUNK_OVERLAP = 100; + private static final ChunkingSettings DEFAULT_CHUNKING_SETTINGS = new WordBoundaryChunkingSettings(250, 100); // The maximum number of chunks that is stored for any input text. // If the configured chunker chunks the text into more chunks, each @@ -72,28 +75,44 @@ public record BatchRequestAndListener(BatchRequest batch, ActionListener resultsErrors; private ActionListener> finalListener; - public EmbeddingRequestChunker(List inputs, int maxNumberOfInputsPerBatch) { + public EmbeddingRequestChunker(List inputs, int maxNumberOfInputsPerBatch) { this(inputs, maxNumberOfInputsPerBatch, null); } - public EmbeddingRequestChunker(List inputs, int maxNumberOfInputsPerBatch, int wordsPerChunk, int chunkOverlap) { + public EmbeddingRequestChunker(List inputs, int maxNumberOfInputsPerBatch, int wordsPerChunk, int chunkOverlap) { this(inputs, maxNumberOfInputsPerBatch, new WordBoundaryChunkingSettings(wordsPerChunk, chunkOverlap)); } - public EmbeddingRequestChunker(List inputs, int maxNumberOfInputsPerBatch, ChunkingSettings chunkingSettings) { + public EmbeddingRequestChunker( + List inputs, + int maxNumberOfInputsPerBatch, + ChunkingSettings defaultChunkingSettings + ) { this.resultEmbeddings = new ArrayList<>(inputs.size()); this.resultOffsetStarts = new ArrayList<>(inputs.size()); this.resultOffsetEnds = new ArrayList<>(inputs.size()); this.resultsErrors = new AtomicArray<>(inputs.size()); - if (chunkingSettings == null) { - chunkingSettings = new WordBoundaryChunkingSettings(DEFAULT_WORDS_PER_CHUNK, DEFAULT_CHUNK_OVERLAP); + if (defaultChunkingSettings == null) { + defaultChunkingSettings = DEFAULT_CHUNKING_SETTINGS; } - Chunker chunker = ChunkerBuilder.fromChunkingStrategy(chunkingSettings.getChunkingStrategy()); + + Map chunkers = inputs.stream() + .map(ChunkInferenceInput::chunkingSettings) + .filter(Objects::nonNull) + .map(ChunkingSettings::getChunkingStrategy) + .distinct() + .collect(Collectors.toMap(chunkingStrategy -> chunkingStrategy, ChunkerBuilder::fromChunkingStrategy)); + Chunker defaultChunker = ChunkerBuilder.fromChunkingStrategy(defaultChunkingSettings.getChunkingStrategy()); List allRequests = new ArrayList<>(); for (int inputIndex = 0; inputIndex < inputs.size(); inputIndex++) { - List chunks = chunker.chunk(inputs.get(inputIndex), chunkingSettings); + ChunkingSettings chunkingSettings = inputs.get(inputIndex).chunkingSettings(); + if (chunkingSettings == null) { + chunkingSettings = defaultChunkingSettings; + } + Chunker chunker = chunkers.getOrDefault(chunkingSettings.getChunkingStrategy(), defaultChunker); + List chunks = chunker.chunk(inputs.get(inputIndex).input(), chunkingSettings); int resultCount = Math.min(chunks.size(), MAX_CHUNKS); resultEmbeddings.add(new 
AtomicReferenceArray<>(resultCount)); resultOffsetStarts.add(new ArrayList<>(resultCount)); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/SentenceBoundaryChunkingSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/SentenceBoundaryChunkingSettings.java index 9d6f5bb89218f..6eb16d00748f8 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/SentenceBoundaryChunkingSettings.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/SentenceBoundaryChunkingSettings.java @@ -22,6 +22,7 @@ import java.io.IOException; import java.util.Arrays; +import java.util.Locale; import java.util.Map; import java.util.Objects; import java.util.Set; @@ -54,6 +55,18 @@ public SentenceBoundaryChunkingSettings(StreamInput in) throws IOException { } } + @Override + public Map asMap() { + return Map.of( + ChunkingSettingsOptions.STRATEGY.toString(), + STRATEGY.toString().toLowerCase(Locale.ROOT), + ChunkingSettingsOptions.MAX_CHUNK_SIZE.toString(), + maxChunkSize, + ChunkingSettingsOptions.SENTENCE_OVERLAP.toString(), + sentenceOverlap + ); + } + public static SentenceBoundaryChunkingSettings fromMap(Map map) { ValidationException validationException = new ValidationException(); @@ -141,4 +154,9 @@ public boolean equals(Object o) { public int hashCode() { return Objects.hash(maxChunkSize); } + + @Override + public String toString() { + return Strings.toString(this); + } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/WordBoundaryChunkingSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/WordBoundaryChunkingSettings.java index 7e0378d5b0cd1..97f8aa49ef4d1 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/WordBoundaryChunkingSettings.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/chunking/WordBoundaryChunkingSettings.java @@ -21,6 +21,7 @@ import java.io.IOException; import java.util.Arrays; +import java.util.Locale; import java.util.Map; import java.util.Objects; import java.util.Set; @@ -48,6 +49,26 @@ public WordBoundaryChunkingSettings(StreamInput in) throws IOException { overlap = in.readInt(); } + @Override + public Map asMap() { + return Map.of( + ChunkingSettingsOptions.STRATEGY.toString(), + STRATEGY.toString().toLowerCase(Locale.ROOT), + ChunkingSettingsOptions.MAX_CHUNK_SIZE.toString(), + maxChunkSize, + ChunkingSettingsOptions.OVERLAP.toString(), + overlap + ); + } + + public int maxChunkSize() { + return maxChunkSize; + } + + public int overlap() { + return overlap; + } + public static WordBoundaryChunkingSettings fromMap(Map map) { ValidationException validationException = new ValidationException(); @@ -130,4 +151,9 @@ public boolean equals(Object o) { public int hashCode() { return Objects.hash(maxChunkSize, overlap); } + + @Override + public String toString() { + return Strings.toString(this); + } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/AlibabaCloudSearchEmbeddingsRequestManager.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/AlibabaCloudSearchEmbeddingsRequestManager.java index 5fb30b7df3d2e..270e18db8a624 100644 --- 
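Together with ChunkingSettingsBuilder.fromMap, asMap forms a round trip: settings flatten into the metadata map on write and are rebuilt from it on read. A sketch under that assumption (generic types reconstructed as Map<String, Object>):

import org.elasticsearch.inference.ChunkingSettings;
import org.elasticsearch.xpack.inference.chunking.ChunkingSettingsBuilder;
import org.elasticsearch.xpack.inference.chunking.WordBoundaryChunkingSettings;

import java.util.Map;

final class AsMapRoundTripSketch {

    static boolean roundTrips() {
        ChunkingSettings original = new WordBoundaryChunkingSettings(250, 100);
        Map<String, Object> asMap = original.asMap(); // {strategy=word_boundary, max_chunk_size=250, overlap=100}
        ChunkingSettings rebuilt = ChunkingSettingsBuilder.fromMap(asMap, false);
        return original.equals(rebuilt);              // true if the round trip is lossless
    }
}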
a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/AlibabaCloudSearchEmbeddingsRequestManager.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/AlibabaCloudSearchEmbeddingsRequestManager.java @@ -71,7 +71,7 @@ public void execute( ActionListener listener ) { EmbeddingsInput input = EmbeddingsInput.of(inferenceInputs); - List docsInput = input.getInputs(); + List docsInput = input.getStringInputs(); InputType inputType = input.getInputType(); AlibabaCloudSearchEmbeddingsRequest request = new AlibabaCloudSearchEmbeddingsRequest(account, docsInput, inputType, model); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/AlibabaCloudSearchSparseRequestManager.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/AlibabaCloudSearchSparseRequestManager.java index 6f79cd3cf71e4..1d0c780b05510 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/AlibabaCloudSearchSparseRequestManager.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/AlibabaCloudSearchSparseRequestManager.java @@ -71,7 +71,7 @@ public void execute( ActionListener listener ) { EmbeddingsInput input = EmbeddingsInput.of(inferenceInputs); - List docsInput = input.getInputs(); + List docsInput = input.getStringInputs(); InputType inputType = input.getInputType(); AlibabaCloudSearchSparseRequest request = new AlibabaCloudSearchSparseRequest(account, docsInput, inputType, model); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/AmazonBedrockEmbeddingsRequestManager.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/AmazonBedrockEmbeddingsRequestManager.java index fd04950ed6459..33a83de2486c5 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/AmazonBedrockEmbeddingsRequestManager.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/AmazonBedrockEmbeddingsRequestManager.java @@ -55,7 +55,7 @@ public void execute( ActionListener listener ) { EmbeddingsInput input = EmbeddingsInput.of(inferenceInputs); - List docsInput = input.getInputs(); + List docsInput = input.getStringInputs(); InputType inputType = input.getInputType(); var serviceSettings = embeddingsModel.getServiceSettings(); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/AzureAiStudioEmbeddingsRequestManager.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/AzureAiStudioEmbeddingsRequestManager.java index 32a29490d7d24..afef9ca87ad63 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/AzureAiStudioEmbeddingsRequestManager.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/AzureAiStudioEmbeddingsRequestManager.java @@ -48,7 +48,7 @@ public void execute( ActionListener listener ) { EmbeddingsInput input = EmbeddingsInput.of(inferenceInputs); - List docsInput = input.getInputs(); + List docsInput = input.getStringInputs(); InputType inputType = input.getInputType(); var truncatedInput = truncate(docsInput, model.getServiceSettings().maxInputTokens()); diff --git 
a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/AzureOpenAiEmbeddingsRequestManager.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/AzureOpenAiEmbeddingsRequestManager.java index 21fd154cc604f..80d7cfe9ad86e 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/AzureOpenAiEmbeddingsRequestManager.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/AzureOpenAiEmbeddingsRequestManager.java @@ -62,7 +62,7 @@ public void execute( ActionListener listener ) { EmbeddingsInput input = EmbeddingsInput.of(inferenceInputs); - List docsInput = input.getInputs(); + List docsInput = input.getStringInputs(); InputType inputType = input.getInputType(); var truncatedInput = truncate(docsInput, model.getServiceSettings().maxInputTokens()); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/CohereEmbeddingsRequestManager.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/CohereEmbeddingsRequestManager.java index fb147047fb07e..f0e76fe4dbef3 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/CohereEmbeddingsRequestManager.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/CohereEmbeddingsRequestManager.java @@ -51,7 +51,7 @@ public void execute( ActionListener listener ) { EmbeddingsInput input = EmbeddingsInput.of(inferenceInputs); - List docsInput = input.getInputs(); + List docsInput = input.getStringInputs(); InputType inputType = input.getInputType(); CohereEmbeddingsRequest request = new CohereEmbeddingsRequest(docsInput, inputType, model); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/ElasticInferenceServiceSparseEmbeddingsRequestManager.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/ElasticInferenceServiceSparseEmbeddingsRequestManager.java index 27730bc02df07..6b6cfff37fa20 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/ElasticInferenceServiceSparseEmbeddingsRequestManager.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/ElasticInferenceServiceSparseEmbeddingsRequestManager.java @@ -67,7 +67,7 @@ public void execute( ActionListener listener ) { EmbeddingsInput input = EmbeddingsInput.of(inferenceInputs); - List docsInput = input.getInputs(); + List docsInput = input.getStringInputs(); InputType inputType = input.getInputType(); var truncatedInput = truncate(docsInput, model.getServiceSettings().maxInputTokens()); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/EmbeddingsInput.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/EmbeddingsInput.java index 3f59386082d73..1e188d0f7bf5b 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/EmbeddingsInput.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/EmbeddingsInput.java @@ -8,11 +8,14 @@ package org.elasticsearch.xpack.inference.external.http.sender; import org.elasticsearch.core.Nullable; +import 
org.elasticsearch.inference.ChunkInferenceInput;
+import org.elasticsearch.inference.ChunkingSettings;
 import org.elasticsearch.inference.InputType;
 
 import java.util.List;
 import java.util.Objects;
 import java.util.function.Supplier;
+import java.util.stream.Collectors;
 
 public class EmbeddingsInput extends InferenceInputs {
 
@@ -24,30 +27,42 @@ public static EmbeddingsInput of(InferenceInputs inferenceInputs) {
         return (EmbeddingsInput) inferenceInputs;
     }
 
-    private final Supplier<List<String>> listSupplier;
+    private final Supplier<List<ChunkInferenceInput>> listSupplier;
     private final InputType inputType;
 
-    public EmbeddingsInput(List<String> input, @Nullable InputType inputType) {
+    public EmbeddingsInput(List<ChunkInferenceInput> input, @Nullable InputType inputType) {
         this(input, inputType, false);
     }
 
-    public EmbeddingsInput(Supplier<List<String>> inputSupplier, @Nullable InputType inputType) {
+    public EmbeddingsInput(Supplier<List<ChunkInferenceInput>> inputSupplier, @Nullable InputType inputType) {
         super(false);
         this.listSupplier = Objects.requireNonNull(inputSupplier);
         this.inputType = inputType;
     }
 
-    public EmbeddingsInput(List<String> input, @Nullable InputType inputType, boolean stream) {
+    public EmbeddingsInput(List<String> input, @Nullable ChunkingSettings chunkingSettings, @Nullable InputType inputType) {
+        this(input.stream().map(i -> new ChunkInferenceInput(i, chunkingSettings)).collect(Collectors.toList()), inputType, false);
+    }
+
+    public EmbeddingsInput(List<ChunkInferenceInput> input, @Nullable InputType inputType, boolean stream) {
         super(stream);
         Objects.requireNonNull(input);
         this.listSupplier = () -> input;
         this.inputType = inputType;
     }
 
-    public List<String> getInputs() {
+    public List<ChunkInferenceInput> getInputs() {
         return this.listSupplier.get();
     }
 
+    public static EmbeddingsInput fromStrings(List<String> input, @Nullable InputType inputType) {
+        return new EmbeddingsInput(input, null, inputType);
+    }
+
+    public List<String> getStringInputs() {
+        return getInputs().stream().map(ChunkInferenceInput::input).collect(Collectors.toList());
+    }
+
     public InputType getInputType() {
         return this.inputType;
     }
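EmbeddingsInput now carries ChunkInferenceInput values end to end: fromStrings wraps plain strings with no per-input chunking settings, and getStringInputs unwraps them again for request managers that only need the text (the long run of one-line manager changes that follows). A short sketch of that round trip, using a simplified stand-in record rather than the real class:

    import java.util.List;

    public class EmbeddingsInputSketch {
        // Stand-in: the real ChunkInferenceInput also carries ChunkingSettings.
        record ChunkInferenceInput(String input) {}

        public static void main(String[] args) {
            // fromStrings-style wrapping: no per-input chunking settings.
            List<ChunkInferenceInput> wrapped = List.of("first doc", "second doc")
                .stream()
                .map(ChunkInferenceInput::new)
                .toList();

            // getStringInputs-style unwrapping for managers that only need text.
            List<String> strings = wrapped.stream().map(ChunkInferenceInput::input).toList();
            System.out.println(strings); // prints [first doc, second doc]
        }
    }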
diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/GoogleAiStudioEmbeddingsRequestManager.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/GoogleAiStudioEmbeddingsRequestManager.java
index 49388112d39e9..278e198cb2548 100644
--- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/GoogleAiStudioEmbeddingsRequestManager.java
+++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/GoogleAiStudioEmbeddingsRequestManager.java
@@ -55,7 +55,7 @@ public void execute(
         ActionListener<InferenceServiceResults> listener
     ) {
         EmbeddingsInput input = EmbeddingsInput.of(inferenceInputs);
-        List<String> docsInput = input.getInputs();
+        List<String> docsInput = input.getStringInputs();
         InputType inputType = input.getInputType();
 
         var truncatedInput = truncate(docsInput, model.getServiceSettings().maxInputTokens());
diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/GoogleVertexAiEmbeddingsRequestManager.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/GoogleVertexAiEmbeddingsRequestManager.java
index c5e227358b47d..dd6582ef71d26 100644
--- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/GoogleVertexAiEmbeddingsRequestManager.java
+++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/GoogleVertexAiEmbeddingsRequestManager.java
@@ -63,7 +63,7 @@ public void execute(
         ActionListener<InferenceServiceResults> listener
     ) {
         EmbeddingsInput input = EmbeddingsInput.of(inferenceInputs);
-        List<String> docsInput = input.getInputs();
+        List<String> docsInput = input.getStringInputs();
         InputType inputType = input.getInputType();
 
         var truncatedInput = truncate(docsInput, model.getServiceSettings().maxInputTokens());
diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/HuggingFaceRequestManager.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/HuggingFaceRequestManager.java
index 2072bb305c078..07888f7c2ec9d 100644
--- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/HuggingFaceRequestManager.java
+++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/HuggingFaceRequestManager.java
@@ -60,7 +60,7 @@ public void execute(
         Supplier<Boolean> hasRequestCompletedFunction,
         ActionListener<InferenceServiceResults> listener
     ) {
-        List<String> docsInput = EmbeddingsInput.of(inferenceInputs).getInputs();
+        List<String> docsInput = EmbeddingsInput.of(inferenceInputs).getStringInputs();
         var truncatedInput = truncate(docsInput, model.getTokenLimit());
 
         var request = new HuggingFaceInferenceRequest(truncator, truncatedInput, model);
diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/IbmWatsonxEmbeddingsRequestManager.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/IbmWatsonxEmbeddingsRequestManager.java
index 4d3063d9ab673..a3a2291c9e210 100644
--- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/IbmWatsonxEmbeddingsRequestManager.java
+++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/IbmWatsonxEmbeddingsRequestManager.java
@@ -53,7 +53,7 @@ public void execute(
         Supplier<Boolean> hasRequestCompletedFunction,
         ActionListener<InferenceServiceResults> listener
     ) {
-        List<String> docsInput = EmbeddingsInput.of(inferenceInputs).getInputs();
+        List<String> docsInput = EmbeddingsInput.of(inferenceInputs).getStringInputs();
         var truncatedInput = truncate(docsInput, model.getServiceSettings().maxInputTokens());
 
         execute(
diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/JinaAIEmbeddingsRequestManager.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/JinaAIEmbeddingsRequestManager.java
index 85f97ed45a869..b59b8770906d1 100644
--- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/JinaAIEmbeddingsRequestManager.java
+++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/JinaAIEmbeddingsRequestManager.java
@@ -51,7 +51,7 @@ public void execute(
         ActionListener<InferenceServiceResults> listener
     ) {
         EmbeddingsInput input = EmbeddingsInput.of(inferenceInputs);
-        List<String> docsInput = input.getInputs();
+        List<String> docsInput = input.getStringInputs();
         InputType inputType = input.getInputType();
 
         JinaAIEmbeddingsRequest request = new JinaAIEmbeddingsRequest(docsInput, inputType, model);
diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/MistralEmbeddingsRequestManager.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/MistralEmbeddingsRequestManager.java
index b224d5ceb14ae..6f77fa3f6e985 100644
---
a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/MistralEmbeddingsRequestManager.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/MistralEmbeddingsRequestManager.java @@ -57,7 +57,7 @@ public void execute( Supplier hasRequestCompletedFunction, ActionListener listener ) { - List docsInput = EmbeddingsInput.of(inferenceInputs).getInputs(); + List docsInput = EmbeddingsInput.of(inferenceInputs).getStringInputs(); var truncatedInput = truncate(docsInput, model.getServiceSettings().maxInputTokens()); MistralEmbeddingsRequest request = new MistralEmbeddingsRequest(truncator, truncatedInput, model); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/TruncatingRequestManager.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/TruncatingRequestManager.java index 4a485f87858aa..c39387d647f77 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/TruncatingRequestManager.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/TruncatingRequestManager.java @@ -52,7 +52,7 @@ public void execute( Supplier hasRequestCompletedFunction, ActionListener listener ) { - var docsInput = inferenceInputs.castTo(EmbeddingsInput.class).getInputs(); + var docsInput = inferenceInputs.castTo(EmbeddingsInput.class).getStringInputs(); var truncatedInput = truncate(docsInput, maxInputTokens); var request = requestCreator.apply(truncatedInput); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextField.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextField.java index 7ed17d34ae8b3..b6652e499b9fc 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextField.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextField.java @@ -15,6 +15,7 @@ import org.elasticsearch.core.Nullable; import org.elasticsearch.index.IndexVersions; import org.elasticsearch.inference.ChunkedInference; +import org.elasticsearch.inference.ChunkingSettings; import org.elasticsearch.inference.MinimalServiceSettings; import org.elasticsearch.xcontent.ConstructingObjectParser; import org.elasticsearch.xcontent.DeprecationHandler; @@ -27,6 +28,7 @@ import org.elasticsearch.xcontent.XContentParserConfiguration; import org.elasticsearch.xcontent.XContentType; import org.elasticsearch.xcontent.support.MapXContentParser; +import org.elasticsearch.xpack.inference.chunking.ChunkingSettingsBuilder; import java.io.IOException; import java.util.ArrayList; @@ -69,8 +71,14 @@ public record SemanticTextField( static final String CHUNKED_START_OFFSET_FIELD = "start_offset"; static final String CHUNKED_END_OFFSET_FIELD = "end_offset"; static final String MODEL_SETTINGS_FIELD = "model_settings"; + static final String CHUNKING_SETTINGS_FIELD = "chunking_settings"; - public record InferenceResult(String inferenceId, MinimalServiceSettings modelSettings, Map> chunks) {} + public record InferenceResult( + String inferenceId, + MinimalServiceSettings modelSettings, + ChunkingSettings chunkingSettings, + Map> chunks + ) {} public record Chunk(@Nullable String text, int startOffset, int endOffset, BytesReference rawEmbeddings) {} @@ -120,6 +128,18 @@ static MinimalServiceSettings 
parseModelSettingsFromMap(Object node) {
         }
     }
 
+    static ChunkingSettings parseChunkingSettingsFromMap(Object node) {
+        if (node == null) {
+            return null;
+        }
+        try {
+            Map<String, Object> map = XContentMapValues.nodeMapValue(node, CHUNKING_SETTINGS_FIELD);
+            return ChunkingSettingsBuilder.fromMap(map, false);
+        } catch (Exception exc) {
+            throw new ElasticsearchException(exc);
+        }
+    }
+
     @Override
     public List<String> originalValues() {
         return originalValues != null ? originalValues : Collections.emptyList();
@@ -135,6 +155,10 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
         builder.startObject(INFERENCE_FIELD);
         builder.field(INFERENCE_ID_FIELD, inference.inferenceId);
         builder.field(MODEL_SETTINGS_FIELD, inference.modelSettings);
+        if (inference.chunkingSettings != null) {
+            builder.field(CHUNKING_SETTINGS_FIELD, inference.chunkingSettings);
+        }
+
         if (useLegacyFormat) {
             builder.startArray(CHUNKS_FIELD);
         } else {
@@ -178,6 +202,7 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
     private static final ConstructingObjectParser<SemanticTextField, ParserContext> SEMANTIC_TEXT_FIELD_PARSER =
         new ConstructingObjectParser<>(SemanticTextFieldMapper.CONTENT_TYPE, true, (args, context) -> {
             List<String> originalValues = (List<String>) args[0];
+            InferenceResult inferenceResult = (InferenceResult) args[1];
             if (context.useLegacyFormat() == false) {
                 if (originalValues != null && originalValues.isEmpty() == false) {
                     throw new IllegalArgumentException("Unknown field [" + TEXT_FIELD + "]");
@@ -188,7 +213,7 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
                 context.useLegacyFormat(),
                 context.fieldName(),
                 originalValues,
-                (InferenceResult) args[1],
+                inferenceResult,
                 context.xContentType()
             );
         });
@@ -197,7 +222,13 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
     private static final ConstructingObjectParser<InferenceResult, ParserContext> INFERENCE_RESULT_PARSER = new ConstructingObjectParser<>(
         INFERENCE_FIELD,
         true,
-        args -> new InferenceResult((String) args[0], (MinimalServiceSettings) args[1], (Map<String, List<Chunk>>) args[2])
+        args -> {
+            String inferenceId = (String) args[0];
+            MinimalServiceSettings modelSettings = (MinimalServiceSettings) args[1];
+            Map<String, Object> chunkingSettings = (Map<String, Object>) args[2];
+            Map<String, List<Chunk>> chunks = (Map<String, List<Chunk>>) args[3];
+            return new InferenceResult(inferenceId, modelSettings, ChunkingSettingsBuilder.fromMap(chunkingSettings, false), chunks);
+        }
     );
 
     private static final ConstructingObjectParser<Chunk, ParserContext> CHUNKS_PARSER = new ConstructingObjectParser<>(
@@ -218,11 +249,17 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
         INFERENCE_RESULT_PARSER.declareString(constructorArg(), new ParseField(INFERENCE_ID_FIELD));
         INFERENCE_RESULT_PARSER.declareObjectOrNull(
-            constructorArg(),
+            optionalConstructorArg(),
             (p, c) -> MinimalServiceSettings.parse(p),
             null,
             new ParseField(MODEL_SETTINGS_FIELD)
         );
+        INFERENCE_RESULT_PARSER.declareObjectOrNull(
+            optionalConstructorArg(),
+            (p, c) -> p.map(),
+            null,
+            new ParseField(CHUNKING_SETTINGS_FIELD)
+        );
         INFERENCE_RESULT_PARSER.declareField(constructorArg(), (p, c) -> {
             if (c.useLegacyFormat()) {
                 return Map.of(c.fieldName, parseChunksArrayLegacy(p, c));
@@ -297,7 +334,7 @@ public static List<Chunk> toSemanticTextFieldChunksLegacy(String input, ChunkedI
         return chunks;
     }
 
-    public static Chunk toSemanticTextFieldChunkLegacy(String input, ChunkedInference.Chunk chunk) {
+    public static Chunk toSemanticTextFieldChunkLegacy(String input, org.elasticsearch.inference.ChunkedInference.Chunk chunk) {
         var text = input.substring(chunk.textOffset().start(), chunk.textOffset().end());
         return new Chunk(text, -1, -1, chunk.bytesReference());
     }
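Note that parseChunkingSettingsFromMap treats an absent chunking_settings node as null rather than an error, so mappings and documents written before this change keep parsing and simply fall back to default chunking. A hedged sketch of that null-tolerant shape; the Map-based check is illustrative, while the real path goes through XContentMapValues.nodeMapValue and ChunkingSettingsBuilder.fromMap(map, false):

    import java.util.Map;

    public class ChunkingSettingsParseSketch {
        static Map<?, ?> parseChunkingSettings(Object node) {
            if (node == null) {
                return null; // field absent: caller falls back to defaults
            }
            if (node instanceof Map<?, ?> map) {
                return map; // stands in for ChunkingSettingsBuilder.fromMap(map, false)
            }
            throw new IllegalArgumentException("chunking_settings must be an object");
        }

        public static void main(String[] args) {
            System.out.println(parseChunkingSettings(null)); // null
            System.out.println(parseChunkingSettings(Map.of("strategy", "word", "max_chunk_size", 250, "overlap", 100)));
        }
    }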
diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapper.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapper.java
index 24372dc1539a6..3a942a8e73537 100644
--- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapper.java
+++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapper.java
@@ -58,6 +58,7 @@
 import org.elasticsearch.index.query.NestedQueryBuilder;
 import org.elasticsearch.index.query.QueryBuilder;
 import org.elasticsearch.index.query.SearchExecutionContext;
+import org.elasticsearch.inference.ChunkingSettings;
 import org.elasticsearch.inference.InferenceResults;
 import org.elasticsearch.inference.MinimalServiceSettings;
 import org.elasticsearch.inference.SimilarityMeasure;
@@ -94,6 +95,7 @@
 import static org.elasticsearch.search.SearchService.DEFAULT_SIZE;
 import static org.elasticsearch.xpack.inference.mapper.SemanticTextField.CHUNKED_EMBEDDINGS_FIELD;
 import static org.elasticsearch.xpack.inference.mapper.SemanticTextField.CHUNKED_OFFSET_FIELD;
+import static org.elasticsearch.xpack.inference.mapper.SemanticTextField.CHUNKING_SETTINGS_FIELD;
 import static org.elasticsearch.xpack.inference.mapper.SemanticTextField.CHUNKS_FIELD;
 import static org.elasticsearch.xpack.inference.mapper.SemanticTextField.INFERENCE_FIELD;
 import static org.elasticsearch.xpack.inference.mapper.SemanticTextField.INFERENCE_ID_FIELD;
@@ -120,6 +122,7 @@ public class SemanticTextFieldMapper extends FieldMapper implements InferenceFie
     public static final NodeFeature SEMANTIC_TEXT_HANDLE_EMPTY_INPUT = new NodeFeature("semantic_text.handle_empty_input");
     public static final NodeFeature SEMANTIC_TEXT_SKIP_INFERENCE_FIELDS = new NodeFeature("semantic_text.skip_inference_fields");
     public static final NodeFeature SEMANTIC_TEXT_BIT_VECTOR_SUPPORT = new NodeFeature("semantic_text.bit_vector_support");
+    public static final NodeFeature SEMANTIC_TEXT_SUPPORT_CHUNKING_CONFIG = new NodeFeature("semantic_text.support_chunking_config");
 
     public static final String CONTENT_TYPE = "semantic_text";
     public static final String DEFAULT_ELSER_2_INFERENCE_ID = DEFAULT_ELSER_ID;
@@ -177,6 +180,17 @@ public static class Builder extends FieldMapper.Builder {
             Objects::toString
         ).acceptsNull().setMergeValidator(SemanticTextFieldMapper::canMergeModelSettings);
 
+        @SuppressWarnings("unchecked")
+        private final Parameter<ChunkingSettings> chunkingSettings = new Parameter<>(
+            CHUNKING_SETTINGS_FIELD,
+            true,
+            () -> null,
+            (n, c, o) -> SemanticTextField.parseChunkingSettingsFromMap(o),
+            mapper -> ((SemanticTextFieldType) mapper.fieldType()).chunkingSettings,
+            XContentBuilder::field,
+            Objects::toString
+        ).acceptsNull();
+
         private final Parameter<Map<String, String>> meta = Parameter.metaParam();
 
         private Function<MapperBuilderContext, ObjectMapper> inferenceFieldBuilder;
@@ -219,9 +233,14 @@ public Builder setModelSettings(MinimalServiceSettings value) {
             return this;
         }
 
+        public Builder setChunkingSettings(ChunkingSettings value) {
+            this.chunkingSettings.setValue(value);
+            return this;
+        }
+
         @Override
         protected Parameter<?>[] getParameters() {
-            return new Parameter<?>[] { inferenceId, searchInferenceId, modelSettings, meta };
+            return new Parameter<?>[] { inferenceId, searchInferenceId, modelSettings, chunkingSettings, meta };
         }
 
         @Override
@@ -263,6 +282,7 @@ public SemanticTextFieldMapper build(MapperBuilderContext context) {
             inferenceId.getValue(),
             searchInferenceId.getValue(),
             modelSettings.getValue(),
+            chunkingSettings.getValue(),
             inferenceField,
             useLegacyFormat,
             meta.getValue()
@@ -521,7 +541,10 @@ public InferenceFieldMetadata getMetadata(Set<String> sourcePaths) {
         String[] copyFields = sourcePaths.toArray(String[]::new);
         // ensure consistent order
         Arrays.sort(copyFields);
-        return new InferenceFieldMetadata(fullPath(), fieldType().getInferenceId(), fieldType().getSearchInferenceId(), copyFields);
+        ChunkingSettings fieldTypeChunkingSettings = fieldType().getChunkingSettings();
+        Map<String, Object> asMap = fieldTypeChunkingSettings != null ? fieldTypeChunkingSettings.asMap() : null;
+
+        return new InferenceFieldMetadata(fullPath(), fieldType().getInferenceId(), fieldType().getSearchInferenceId(), copyFields, asMap);
     }
 
     @Override
@@ -550,6 +573,7 @@ public static class SemanticTextFieldType extends SimpleMappedFieldType {
         private final String inferenceId;
         private final String searchInferenceId;
         private final MinimalServiceSettings modelSettings;
+        private final ChunkingSettings chunkingSettings;
         private final ObjectMapper inferenceField;
         private final boolean useLegacyFormat;
 
@@ -558,6 +582,7 @@ public SemanticTextFieldType(
             String inferenceId,
             String searchInferenceId,
             MinimalServiceSettings modelSettings,
+            ChunkingSettings chunkingSettings,
             ObjectMapper inferenceField,
             boolean useLegacyFormat,
             Map<String, String> meta
@@ -566,6 +591,7 @@ public SemanticTextFieldType(
             this.inferenceId = inferenceId;
             this.searchInferenceId = searchInferenceId;
             this.modelSettings = modelSettings;
+            this.chunkingSettings = chunkingSettings;
             this.inferenceField = inferenceField;
             this.useLegacyFormat = useLegacyFormat;
         }
@@ -601,6 +627,10 @@ public MinimalServiceSettings getModelSettings() {
             return modelSettings;
         }
 
+        public ChunkingSettings getChunkingSettings() {
+            return chunkingSettings;
+        }
+
         public ObjectMapper getInferenceField() {
             return inferenceField;
         }
@@ -873,7 +903,7 @@ public List fetchValues(Source source, int doc, List ignoredValu
                         useLegacyFormat,
                         name(),
                         null,
-                        new SemanticTextField.InferenceResult(inferenceId, modelSettings, chunkMap),
+                        new SemanticTextField.InferenceResult(inferenceId, modelSettings, chunkingSettings, chunkMap),
                         source.sourceContentType()
                     )
                 );
diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/SenderService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/SenderService.java
index ddde0699ec6c7..ff8ae6fd5aac3 100644
--- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/SenderService.java
+++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/SenderService.java
@@ -14,6 +14,7 @@
 import org.elasticsearch.core.Nullable;
 import org.elasticsearch.core.Strings;
 import org.elasticsearch.core.TimeValue;
+import org.elasticsearch.inference.ChunkInferenceInput;
 import org.elasticsearch.inference.ChunkedInference;
 import org.elasticsearch.inference.InferenceService;
 import org.elasticsearch.inference.InferenceServiceResults;
@@ -70,29 +71,31 @@ public void infer(
         ActionListener<InferenceServiceResults> listener
     ) {
         init();
-        var inferenceInput = createInput(this, model, input, inputType, query, returnDocuments, topN, stream);
+        var chunkInferenceInput = input.stream().map(i -> new ChunkInferenceInput(i, null)).toList();
+        var inferenceInput = createInput(this, model, chunkInferenceInput, inputType, query, returnDocuments, topN, stream);
         doInfer(model,
inferenceInput, taskSettings, timeout, listener); } private static InferenceInputs createInput( SenderService service, Model model, - List input, + List input, InputType inputType, @Nullable String query, @Nullable Boolean returnDocuments, @Nullable Integer topN, boolean stream ) { + List textInput = ChunkInferenceInput.inputs(input); return switch (model.getTaskType()) { - case COMPLETION, CHAT_COMPLETION -> new ChatCompletionInput(input, stream); + case COMPLETION, CHAT_COMPLETION -> new ChatCompletionInput(textInput, stream); case RERANK -> { ValidationException validationException = new ValidationException(); service.validateRerankParameters(returnDocuments, topN, validationException); if (validationException.validationErrors().isEmpty() == false) { throw validationException; } - yield new QueryAndDocsInputs(query, input, returnDocuments, topN, stream); + yield new QueryAndDocsInputs(query, textInput, returnDocuments, topN, stream); } case TEXT_EMBEDDING, SPARSE_EMBEDDING -> { ValidationException validationException = new ValidationException(); @@ -124,7 +127,7 @@ public void unifiedCompletionInfer( public void chunkedInfer( Model model, @Nullable String query, - List input, + List input, Map taskSettings, InputType inputType, TimeValue timeout, diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/AlibabaCloudSearchService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/AlibabaCloudSearchService.java index 1c2fddc6f5264..ac0d0df06b48d 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/AlibabaCloudSearchService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/AlibabaCloudSearchService.java @@ -344,7 +344,7 @@ protected void doChunkedInfer( for (var request : batchedRequests) { var action = alibabaCloudSearchModel.accept(actionCreator, taskSettings); - action.execute(new EmbeddingsInput(request.batch().inputs(), inputType), timeout, request.listener()); + action.execute(EmbeddingsInput.fromStrings(request.batch().inputs().get(), inputType), timeout, request.listener()); } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockService.java index 145865699d74f..38d8d61873ce5 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockService.java @@ -170,7 +170,7 @@ protected void doChunkedInfer( for (var request : batchedRequests) { var action = baseAmazonBedrockModel.accept(actionCreator, taskSettings); - action.execute(new EmbeddingsInput(request.batch().inputs(), inputType), timeout, request.listener()); + action.execute(EmbeddingsInput.fromStrings(request.batch().inputs().get(), inputType), timeout, request.listener()); } } else { listener.onFailure(createInvalidModelException(model)); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureaistudio/AzureAiStudioService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureaistudio/AzureAiStudioService.java index d78a33ded916a..a70f44b91f9f8 
100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureaistudio/AzureAiStudioService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureaistudio/AzureAiStudioService.java @@ -141,7 +141,7 @@ protected void doChunkedInfer( for (var request : batchedRequests) { var action = baseAzureAiStudioModel.accept(actionCreator, taskSettings); - action.execute(new EmbeddingsInput(request.batch().inputs(), inputType), timeout, request.listener()); + action.execute(EmbeddingsInput.fromStrings(request.batch().inputs().get(), inputType), timeout, request.listener()); } } else { listener.onFailure(createInvalidModelException(model)); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureopenai/AzureOpenAiService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureopenai/AzureOpenAiService.java index a47e0e5d55972..03778e4471042 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureopenai/AzureOpenAiService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureopenai/AzureOpenAiService.java @@ -294,7 +294,7 @@ protected void doChunkedInfer( for (var request : batchedRequests) { var action = azureOpenAiModel.accept(actionCreator, taskSettings); - action.execute(new EmbeddingsInput(request.batch().inputs(), inputType), timeout, request.listener()); + action.execute(EmbeddingsInput.fromStrings(request.batch().inputs().get(), inputType), timeout, request.listener()); } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/CohereService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/CohereService.java index 2995a3a71bff7..66dc7a1de9a75 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/CohereService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/CohereService.java @@ -307,7 +307,7 @@ protected void doChunkedInfer( for (var request : batchedRequests) { var action = cohereModel.accept(actionCreator, taskSettings); - action.execute(new EmbeddingsInput(request.batch().inputs(), inputType), timeout, request.listener()); + action.execute(EmbeddingsInput.fromStrings(request.batch().inputs().get(), inputType), timeout, request.listener()); } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceService.java index 3e5258b1850bd..438b0ca478f9f 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceService.java @@ -435,7 +435,7 @@ public void checkModelConfig(Model model, ActionListener listener) { private static List translateToChunkedResults(InferenceInputs inputs, InferenceServiceResults inferenceResults) { if (inferenceResults instanceof SparseEmbeddingResults sparseEmbeddingResults) { - var inputsAsList = EmbeddingsInput.of(inputs).getInputs(); + var inputsAsList = EmbeddingsInput.of(inputs).getStringInputs(); return ChunkedInferenceEmbedding.listOf(inputsAsList, sparseEmbeddingResults); } else if 
(inferenceResults instanceof ErrorInferenceResults error) { return List.of(new ChunkedInferenceError(error.getException())); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalService.java index de9ab21039a2b..ed331e9df658e 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalService.java @@ -20,6 +20,7 @@ import org.elasticsearch.core.Nullable; import org.elasticsearch.core.Strings; import org.elasticsearch.core.TimeValue; +import org.elasticsearch.inference.ChunkInferenceInput; import org.elasticsearch.inference.ChunkedInference; import org.elasticsearch.inference.ChunkingSettings; import org.elasticsearch.inference.InferenceResults; @@ -727,22 +728,11 @@ public void inferRerank( client.execute(InferModelAction.INSTANCE, request, maybeDeployListener); } - public void chunkedInfer( - Model model, - List input, - Map taskSettings, - InputType inputType, - TimeValue timeout, - ActionListener> listener - ) { - chunkedInfer(model, null, input, taskSettings, inputType, timeout, listener); - } - @Override public void chunkedInfer( Model model, @Nullable String query, - List input, + List input, Map taskSettings, InputType inputType, TimeValue timeout, diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googleaistudio/GoogleAiStudioService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googleaistudio/GoogleAiStudioService.java index 1be510330fa85..87e0a1ba67a90 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googleaistudio/GoogleAiStudioService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googleaistudio/GoogleAiStudioService.java @@ -364,7 +364,13 @@ protected void doChunkedInfer( ).batchRequestsWithListeners(listener); for (var request : batchedRequests) { - doInfer(model, new EmbeddingsInput(request.batch().inputs(), inputType), taskSettings, timeout, request.listener()); + doInfer( + model, + EmbeddingsInput.fromStrings(request.batch().inputs().get(), inputType), + taskSettings, + timeout, + request.listener() + ); } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiService.java index 391db8059a261..8526e8abbad4d 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiService.java @@ -250,7 +250,7 @@ protected void doChunkedInfer( for (var request : batchedRequests) { var action = googleVertexAiModel.accept(actionCreator, taskSettings); - action.execute(new EmbeddingsInput(request.batch().inputs(), inputType), timeout, request.listener()); + action.execute(EmbeddingsInput.fromStrings(request.batch().inputs().get(), inputType), timeout, request.listener()); } } diff --git 
a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceService.java index 086403acc5cf1..612748f6ede12 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceService.java @@ -135,7 +135,7 @@ protected void doChunkedInfer( for (var request : batchedRequests) { var action = huggingFaceModel.accept(actionCreator); - action.execute(new EmbeddingsInput(request.batch().inputs(), inputType), timeout, request.listener()); + action.execute(EmbeddingsInput.fromStrings(request.batch().inputs().get(), inputType), timeout, request.listener()); } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/elser/HuggingFaceElserService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/elser/HuggingFaceElserService.java index 60a7f0bfd44b1..8116eaf86e74a 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/elser/HuggingFaceElserService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/elser/HuggingFaceElserService.java @@ -15,6 +15,7 @@ import org.elasticsearch.common.util.LazyInitializable; import org.elasticsearch.core.Nullable; import org.elasticsearch.core.TimeValue; +import org.elasticsearch.inference.ChunkInferenceInput; import org.elasticsearch.inference.ChunkedInference; import org.elasticsearch.inference.ChunkingSettings; import org.elasticsearch.inference.InferenceServiceConfiguration; @@ -112,7 +113,7 @@ protected void doChunkedInfer( private static List translateToChunkedResults(EmbeddingsInput inputs, InferenceServiceResults inferenceResults) { if (inferenceResults instanceof TextEmbeddingFloatResults textEmbeddingResults) { - validateInputSizeAgainstEmbeddings(inputs.getInputs(), textEmbeddingResults.embeddings().size()); + validateInputSizeAgainstEmbeddings(ChunkInferenceInput.inputs(inputs.getInputs()), textEmbeddingResults.embeddings().size()); var results = new ArrayList(inputs.getInputs().size()); @@ -122,7 +123,7 @@ private static List translateToChunkedResults(EmbeddingsInput List.of( new EmbeddingResults.Chunk( textEmbeddingResults.embeddings().get(i), - new ChunkedInference.TextOffset(0, inputs.getInputs().get(i).length()) + new ChunkedInference.TextOffset(0, inputs.getInputs().get(i).input().length()) ) ) ) @@ -130,7 +131,7 @@ private static List translateToChunkedResults(EmbeddingsInput } return results; } else if (inferenceResults instanceof SparseEmbeddingResults sparseEmbeddingResults) { - var inputsAsList = EmbeddingsInput.of(inputs).getInputs(); + var inputsAsList = ChunkInferenceInput.inputs(EmbeddingsInput.of(inputs).getInputs()); return ChunkedInferenceEmbedding.listOf(inputsAsList, sparseEmbeddingResults); } else if (inferenceResults instanceof ErrorInferenceResults error) { return List.of(new ChunkedInferenceError(error.getException())); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/IbmWatsonxService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/IbmWatsonxService.java index 6d9fb9a585149..c01d4d142fe16 100644 
--- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/IbmWatsonxService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/IbmWatsonxService.java @@ -315,11 +315,11 @@ protected void doChunkedInfer( var batchedRequests = new EmbeddingRequestChunker<>( input.getInputs(), EMBEDDING_MAX_BATCH_SIZE, - model.getConfigurations().getChunkingSettings() + ibmWatsonxModel.getConfigurations().getChunkingSettings() ).batchRequestsWithListeners(listener); for (var request : batchedRequests) { var action = ibmWatsonxModel.accept(getActionCreator(getSender(), getServiceComponents()), taskSettings); - action.execute(new EmbeddingsInput(request.batch().inputs(), inputType), timeout, request.listener()); + action.execute(EmbeddingsInput.fromStrings(request.batch().inputs().get(), inputType), timeout, request.listener()); } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/jinaai/JinaAIService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/jinaai/JinaAIService.java index 5e6477bc5396c..afd1d5db213bf 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/jinaai/JinaAIService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/jinaai/JinaAIService.java @@ -288,7 +288,7 @@ protected void doChunkedInfer( for (var request : batchedRequests) { var action = jinaaiModel.accept(actionCreator, taskSettings); - action.execute(new EmbeddingsInput(request.batch().inputs(), inputType), timeout, request.listener()); + action.execute(EmbeddingsInput.fromStrings(request.batch().inputs().get(), inputType), timeout, request.listener()); } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mistral/MistralService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mistral/MistralService.java index 9fd0b4ad906c0..5c6488bfbbda2 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mistral/MistralService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mistral/MistralService.java @@ -123,7 +123,7 @@ protected void doChunkedInfer( for (var request : batchedRequests) { var action = mistralEmbeddingsModel.accept(actionCreator, taskSettings); - action.execute(new EmbeddingsInput(request.batch().inputs(), inputType), timeout, request.listener()); + action.execute(EmbeddingsInput.fromStrings(request.batch().inputs().get(), inputType), timeout, request.listener()); } } else { listener.onFailure(createInvalidModelException(model)); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/OpenAiService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/OpenAiService.java index b41aa9a27af26..094b6b27e158b 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/OpenAiService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/OpenAiService.java @@ -347,7 +347,7 @@ protected void doChunkedInfer( for (var request : batchedRequests) { var action = openAiModel.accept(actionCreator, taskSettings); - action.execute(new EmbeddingsInput(request.batch().inputs(), inputType), timeout, request.listener()); + 
action.execute(EmbeddingsInput.fromStrings(request.batch().inputs().get(), inputType), timeout, request.listener()); } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/voyageai/VoyageAIService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/voyageai/VoyageAIService.java index 75b58f497d0e9..229266a5e51ed 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/voyageai/VoyageAIService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/voyageai/VoyageAIService.java @@ -308,7 +308,7 @@ protected void doChunkedInfer( for (var request : batchedRequests) { var action = voyageaiModel.accept(actionCreator, taskSettings); - action.execute(new EmbeddingsInput(request.batch().inputs(), inputType), timeout, request.listener()); + action.execute(EmbeddingsInput.fromStrings(request.batch().inputs().get(), inputType), timeout, request.listener()); } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/voyageai/action/VoyageAIActionCreator.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/voyageai/action/VoyageAIActionCreator.java index 80617bf04f905..9b9b9f5fdb67d 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/voyageai/action/VoyageAIActionCreator.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/voyageai/action/VoyageAIActionCreator.java @@ -57,7 +57,7 @@ public ExecutableAction create(VoyageAIEmbeddingsModel model, Map new VoyageAIEmbeddingsRequest( - embeddingsInput.getInputs(), + embeddingsInput.getStringInputs(), embeddingsInput.getInputType(), overriddenModel ), diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/index/query/SemanticKnnVectorQueryRewriteInterceptorTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/index/query/SemanticKnnVectorQueryRewriteInterceptorTests.java index 073bf8f5afb9a..270cdba6d3469 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/index/query/SemanticKnnVectorQueryRewriteInterceptorTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/index/query/SemanticKnnVectorQueryRewriteInterceptorTests.java @@ -56,7 +56,7 @@ public void cleanup() { public void testKnnQueryWithVectorBuilderIsInterceptedAndRewritten() throws IOException { Map inferenceFields = Map.of( FIELD_NAME, - new InferenceFieldMetadata(index.getName(), INFERENCE_ID, new String[] { FIELD_NAME }) + new InferenceFieldMetadata(index.getName(), INFERENCE_ID, new String[] { FIELD_NAME }, null) ); QueryRewriteContext context = createQueryRewriteContext(inferenceFields); QueryVectorBuilder queryVectorBuilder = new TextEmbeddingQueryVectorBuilder(INFERENCE_ID, QUERY); @@ -67,7 +67,7 @@ public void testKnnQueryWithVectorBuilderIsInterceptedAndRewritten() throws IOEx public void testKnnWithQueryBuilderWithoutInferenceIdIsInterceptedAndRewritten() throws IOException { Map inferenceFields = Map.of( FIELD_NAME, - new InferenceFieldMetadata(index.getName(), INFERENCE_ID, new String[] { FIELD_NAME }) + new InferenceFieldMetadata(index.getName(), INFERENCE_ID, new String[] { FIELD_NAME }, null) ); QueryRewriteContext context = createQueryRewriteContext(inferenceFields); QueryVectorBuilder queryVectorBuilder = new TextEmbeddingQueryVectorBuilder(null, QUERY); diff --git 
a/x-pack/plugin/inference/src/test/java/org/elasticsearch/index/query/SemanticMatchQueryRewriteInterceptorTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/index/query/SemanticMatchQueryRewriteInterceptorTests.java index 47705c14d5941..6987ef33ed63d 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/index/query/SemanticMatchQueryRewriteInterceptorTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/index/query/SemanticMatchQueryRewriteInterceptorTests.java @@ -52,7 +52,7 @@ public void cleanup() { public void testMatchQueryOnInferenceFieldIsInterceptedAndRewrittenToSemanticQuery() throws IOException { Map inferenceFields = Map.of( FIELD_NAME, - new InferenceFieldMetadata(index.getName(), "inferenceId", new String[] { FIELD_NAME }) + new InferenceFieldMetadata(index.getName(), "inferenceId", new String[] { FIELD_NAME }, null) ); QueryRewriteContext context = createQueryRewriteContext(inferenceFields); QueryBuilder original = createTestQueryBuilder(); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/index/query/SemanticSparseVectorQueryRewriteInterceptorTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/index/query/SemanticSparseVectorQueryRewriteInterceptorTests.java index 1adad1df7b29b..075955766a0a9 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/index/query/SemanticSparseVectorQueryRewriteInterceptorTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/index/query/SemanticSparseVectorQueryRewriteInterceptorTests.java @@ -54,7 +54,7 @@ public void cleanup() { public void testSparseVectorQueryOnInferenceFieldIsInterceptedAndRewritten() throws IOException { Map inferenceFields = Map.of( FIELD_NAME, - new InferenceFieldMetadata(index.getName(), "inferenceId", new String[] { FIELD_NAME }) + new InferenceFieldMetadata(index.getName(), "inferenceId", new String[] { FIELD_NAME }, null) ); QueryRewriteContext context = createQueryRewriteContext(inferenceFields); QueryBuilder original = new SparseVectorQueryBuilder(FIELD_NAME, INFERENCE_ID, QUERY); @@ -78,7 +78,7 @@ public void testSparseVectorQueryOnInferenceFieldIsInterceptedAndRewritten() thr public void testSparseVectorQueryOnInferenceFieldWithoutInferenceIdIsInterceptedAndRewritten() throws IOException { Map inferenceFields = Map.of( FIELD_NAME, - new InferenceFieldMetadata(index.getName(), "inferenceId", new String[] { FIELD_NAME }) + new InferenceFieldMetadata(index.getName(), "inferenceId", new String[] { FIELD_NAME }, null) ); QueryRewriteContext context = createQueryRewriteContext(inferenceFields); QueryBuilder original = new SparseVectorQueryBuilder(FIELD_NAME, null, QUERY); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilterTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilterTests.java index 74d5b9253f6a5..905acb7363f7b 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilterTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilterTests.java @@ -36,6 +36,7 @@ import org.elasticsearch.index.IndexVersion; import org.elasticsearch.index.mapper.InferenceMetadataFieldsMapper; import org.elasticsearch.index.shard.ShardId; +import org.elasticsearch.inference.ChunkInferenceInput; import 
org.elasticsearch.inference.ChunkedInference; import org.elasticsearch.inference.InferenceService; import org.elasticsearch.inference.InferenceServiceRegistry; @@ -138,7 +139,7 @@ public void testFilterNoop() throws Exception { new BulkItemRequest[0] ); request.setInferenceFieldMap( - Map.of("foo", new InferenceFieldMetadata("foo", "bar", "baz", generateRandomStringArray(5, 10, false, false))) + Map.of("foo", new InferenceFieldMetadata("foo", "bar", "baz", generateRandomStringArray(5, 10, false, false), null)) ); filter.apply(task, TransportShardBulkAction.ACTION_NAME, request, actionListener, actionFilterChain); awaitLatch(chainExecuted, 10, TimeUnit.SECONDS); @@ -171,7 +172,7 @@ public void testLicenseInvalidForInference() throws InterruptedException { Map inferenceFieldMap = Map.of( "obj.field1", - new InferenceFieldMetadata("obj.field1", model.getInferenceEntityId(), new String[] { "obj.field1" }) + new InferenceFieldMetadata("obj.field1", model.getInferenceEntityId(), new String[] { "obj.field1" }, null) ); BulkItemRequest[] items = new BulkItemRequest[1]; items[0] = new BulkItemRequest(0, new IndexRequest("test").source("obj.field1", "Test")); @@ -211,11 +212,11 @@ public void testInferenceNotFound() throws Exception { Map inferenceFieldMap = Map.of( "field1", - new InferenceFieldMetadata("field1", model.getInferenceEntityId(), new String[] { "field1" }), + new InferenceFieldMetadata("field1", model.getInferenceEntityId(), new String[] { "field1" }, null), "field2", - new InferenceFieldMetadata("field2", "inference_0", new String[] { "field2" }), + new InferenceFieldMetadata("field2", "inference_0", new String[] { "field2" }, null), "field3", - new InferenceFieldMetadata("field3", "inference_0", new String[] { "field3" }) + new InferenceFieldMetadata("field3", "inference_0", new String[] { "field3" }, null) ); BulkItemRequest[] items = new BulkItemRequest[10]; for (int i = 0; i < items.length; i++) { @@ -283,7 +284,7 @@ public void testItemFailures() throws Exception { Map inferenceFieldMap = Map.of( "field1", - new InferenceFieldMetadata("field1", model.getInferenceEntityId(), new String[] { "field1" }) + new InferenceFieldMetadata("field1", model.getInferenceEntityId(), new String[] { "field1" }, null) ); BulkItemRequest[] items = new BulkItemRequest[3]; items[0] = new BulkItemRequest(0, new IndexRequest("index").source("field1", "I am a failure")); @@ -351,7 +352,7 @@ public void testExplicitNull() throws Exception { Map inferenceFieldMap = Map.of( "obj.field1", - new InferenceFieldMetadata("obj.field1", model.getInferenceEntityId(), new String[] { "obj.field1" }) + new InferenceFieldMetadata("obj.field1", model.getInferenceEntityId(), new String[] { "obj.field1" }, null) ); Map sourceWithNull = new HashMap<>(); sourceWithNull.put("field1", null); @@ -407,7 +408,7 @@ public void testHandleEmptyInput() throws Exception { Task task = mock(Task.class); Map inferenceFieldMap = Map.of( "semantic_text_field", - new InferenceFieldMetadata("semantic_text_field", model.getInferenceEntityId(), new String[] { "semantic_text_field" }) + new InferenceFieldMetadata("semantic_text_field", model.getInferenceEntityId(), new String[] { "semantic_text_field" }, null) ); BulkItemRequest[] items = new BulkItemRequest[3]; @@ -434,7 +435,7 @@ public void testManyRandomDocs() throws Exception { for (int i = 0; i < numInferenceFields; i++) { String field = randomAlphaOfLengthBetween(5, 10); String inferenceId = randomFrom(inferenceModelMap.keySet()); - inferenceFieldMap.put(field, new 
InferenceFieldMetadata(field, inferenceId, new String[] { field }));
+            inferenceFieldMap.put(field, new InferenceFieldMetadata(field, inferenceId, new String[] { field }, null));
         }
 
         int numRequests = atLeast(100);
@@ -508,12 +509,12 @@ private static ShardBulkInferenceActionFilter createFilter(
         InferenceService inferenceService = mock(InferenceService.class);
         Answer chunkedInferAnswer = invocationOnMock -> {
             StaticModel model = (StaticModel) invocationOnMock.getArguments()[0];
-            List<String> inputs = (List<String>) invocationOnMock.getArguments()[2];
+            List<ChunkInferenceInput> inputs = (List<ChunkInferenceInput>) invocationOnMock.getArguments()[2];
             ActionListener<List<ChunkedInference>> listener = (ActionListener<List<ChunkedInference>>) invocationOnMock.getArguments()[6];
             Runnable runnable = () -> {
                 List<ChunkedInference> results = new ArrayList<>();
-                for (String input : inputs) {
-                    results.add(model.getResults(input));
+                for (ChunkInferenceInput input : inputs) {
+                    results.add(model.getResults(input.input()));
                 }
                 listener.onResponse(results);
             };
@@ -607,13 +608,14 @@ private static BulkItemRequest[] randomBulkItemRequest(
                 useLegacyFormat,
                 field,
                 model,
+                null,
                 List.of(inputText),
                 results,
                 requestContentType
             );
         } else {
             Map<String, List<String>> inputTextMap = Map.of(field, List.of(inputText));
-            semanticTextField = randomSemanticText(useLegacyFormat, field, model, List.of(inputText), requestContentType);
+            semanticTextField = randomSemanticText(useLegacyFormat, field, model, null, List.of(inputText), requestContentType);
             model.putResult(inputText, toChunkedResult(useLegacyFormat, inputTextMap, semanticTextField));
         }
 
diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/chunking/ChunkingSettingsBuilderTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/chunking/ChunkingSettingsBuilderTests.java
index 4a284e0a84ff5..9e6dde60bc641 100644
--- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/chunking/ChunkingSettingsBuilderTests.java
+++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/chunking/ChunkingSettingsBuilderTests.java
@@ -21,14 +21,18 @@ public class ChunkingSettingsBuilderTests extends ESTestCase {
 
     public void testNullChunkingSettingsMap() {
         ChunkingSettings chunkingSettings = ChunkingSettingsBuilder.fromMap(null);
-        assertEquals(ChunkingSettingsBuilder.OLD_DEFAULT_SETTINGS, chunkingSettings);
+
+        ChunkingSettings chunkingSettingsOrNull = ChunkingSettingsBuilder.fromMap(null, false);
+        assertNull(chunkingSettingsOrNull);
     }
 
     public void testEmptyChunkingSettingsMap() {
         ChunkingSettings chunkingSettings = ChunkingSettingsBuilder.fromMap(Collections.emptyMap());
-        assertEquals(DEFAULT_SETTINGS, chunkingSettings);
+
+        ChunkingSettings chunkingSettingsOrNull = ChunkingSettingsBuilder.fromMap(Map.of(), false);
+        assertNull(chunkingSettingsOrNull);
     }
 
     public void testChunkingStrategyNotProvided() {
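The two tests above pin down the new two-argument fromMap overload: when the boolean is false, a null or empty settings map yields null instead of the historical defaults, which lets callers distinguish "not configured" from "configured with defaults". A simplified stand-in for that contract (not the real ChunkingSettingsBuilder):

    import java.util.Map;

    public class FromMapContractSketch {
        static String fromMap(Map<String, Object> map, boolean returnDefaultsIfMissing) {
            if (map == null || map.isEmpty()) {
                return returnDefaultsIfMissing ? "word boundary defaults" : null;
            }
            return "settings for " + map;
        }

        public static void main(String[] args) {
            System.out.println(fromMap(null, true));      // word boundary defaults
            System.out.println(fromMap(null, false));     // null
            System.out.println(fromMap(Map.of(), false)); // null
        }
    }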
diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/chunking/EmbeddingRequestChunkerTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/chunking/EmbeddingRequestChunkerTests.java
index 066e39385c6e2..299c6e8c4f67d 100644
--- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/chunking/EmbeddingRequestChunkerTests.java
+++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/chunking/EmbeddingRequestChunkerTests.java
@@ -8,6 +8,7 @@ package org.elasticsearch.xpack.inference.chunking;
 
 import org.elasticsearch.action.ActionListener;
+import org.elasticsearch.inference.ChunkInferenceInput;
 import org.elasticsearch.inference.ChunkedInference;
 import org.elasticsearch.test.ESTestCase;
 import org.elasticsearch.xpack.core.inference.results.ChunkedInferenceEmbedding;
@@ -46,22 +47,27 @@ public void testEmptyInput_SentenceChunker() {
     }
 
     public void testWhitespaceInput_SentenceChunker() {
-        var batches = new EmbeddingRequestChunker<>(List.of(" "), 10, new SentenceBoundaryChunkingSettings(250, 1))
-            .batchRequestsWithListeners(testListener());
+        var batches = new EmbeddingRequestChunker<>(
+            List.of(new ChunkInferenceInput(" ")),
+            10,
+            new SentenceBoundaryChunkingSettings(250, 1)
+        ).batchRequestsWithListeners(testListener());
         assertThat(batches, hasSize(1));
         assertThat(batches.get(0).batch().inputs().get(), hasSize(1));
         assertThat(batches.get(0).batch().inputs().get().get(0), Matchers.is(" "));
     }
 
     public void testBlankInput_WordChunker() {
-        var batches = new EmbeddingRequestChunker<>(List.of(""), 100, 100, 10).batchRequestsWithListeners(testListener());
+        var batches = new EmbeddingRequestChunker<>(List.of(new ChunkInferenceInput("")), 100, 100, 10).batchRequestsWithListeners(
+            testListener()
+        );
         assertThat(batches, hasSize(1));
         assertThat(batches.get(0).batch().inputs().get(), hasSize(1));
         assertThat(batches.get(0).batch().inputs().get().get(0), Matchers.is(""));
     }
 
     public void testBlankInput_SentenceChunker() {
-        var batches = new EmbeddingRequestChunker<>(List.of(""), 10, new SentenceBoundaryChunkingSettings(250, 1))
+        var batches = new EmbeddingRequestChunker<>(List.of(new ChunkInferenceInput("")), 10, new SentenceBoundaryChunkingSettings(250, 1))
             .batchRequestsWithListeners(testListener());
         assertThat(batches, hasSize(1));
         assertThat(batches.get(0).batch().inputs().get(), hasSize(1));
@@ -69,36 +75,45 @@ public void testBlankInput_SentenceChunker() {
     }
 
     public void testInputThatDoesNotChunk_WordChunker() {
-        var batches = new EmbeddingRequestChunker<>(List.of("ABBAABBA"), 100, 100, 10).batchRequestsWithListeners(testListener());
+        var batches = new EmbeddingRequestChunker<>(List.of(new ChunkInferenceInput("ABBAABBA")), 100, 100, 10).batchRequestsWithListeners(
+            testListener()
+        );
         assertThat(batches, hasSize(1));
         assertThat(batches.get(0).batch().inputs().get(), hasSize(1));
         assertThat(batches.get(0).batch().inputs().get().get(0), Matchers.is("ABBAABBA"));
     }
 
     public void testInputThatDoesNotChunk_SentenceChunker() {
-        var batches = new EmbeddingRequestChunker<>(List.of("ABBAABBA"), 10, new SentenceBoundaryChunkingSettings(250, 1))
-            .batchRequestsWithListeners(testListener());
+        var batches = new EmbeddingRequestChunker<>(
+            List.of(new ChunkInferenceInput("ABBAABBA")),
+            10,
+            new SentenceBoundaryChunkingSettings(250, 1)
+        ).batchRequestsWithListeners(testListener());
         assertThat(batches, hasSize(1));
         assertThat(batches.get(0).batch().inputs().get(), hasSize(1));
         assertThat(batches.get(0).batch().inputs().get().get(0), Matchers.is("ABBAABBA"));
     }
 
     public void testShortInputsAreSingleBatch() {
-        String input = "one chunk";
+        ChunkInferenceInput input = new ChunkInferenceInput("one chunk");
         var batches = new EmbeddingRequestChunker<>(List.of(input), 100, 100, 10).batchRequestsWithListeners(testListener());
         assertThat(batches, hasSize(1));
-        assertThat(batches.get(0).batch().inputs().get(), contains(input));
+        assertThat(batches.get(0).batch().inputs().get(), contains(input.input()));
    }
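The wrapper type driving most of the churn in this file is ChunkInferenceInput. Judging purely from its call sites in these tests (the single-argument constructor, input(), and the static ChunkInferenceInput.inputs(...) used in the batch assertions), it behaves like the following record; this is an inferred sketch, not necessarily the class as merged:

    import org.elasticsearch.core.Nullable;
    import org.elasticsearch.inference.ChunkingSettings;
    import java.util.List;

    // Inferred shape: the input text plus an optional per-input chunking override.
    public record ChunkInferenceInput(String input, @Nullable ChunkingSettings chunkingSettings) {

        public ChunkInferenceInput(String input) {
            this(input, null); // null means: use the field's or the model's chunking settings
        }

        // Unwrap back to plain strings, as the batch-level assertions below do.
        public static List<String> inputs(List<ChunkInferenceInput> chunkInputs) {
            return chunkInputs.stream().map(ChunkInferenceInput::input).toList();
        }
    }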
ChunkInferenceInput("2nd small"), + new ChunkInferenceInput("3rd small") + ); var batches = new EmbeddingRequestChunker<>(inputs, 100, 100, 10).batchRequestsWithListeners(testListener()); assertThat(batches, hasSize(1)); EmbeddingRequestChunker.BatchRequest batch = batches.getFirst().batch(); - assertEquals(batch.inputs().get(), inputs); + assertEquals(batch.inputs().get(), ChunkInferenceInput.inputs(inputs)); for (int i = 0; i < inputs.size(); i++) { var request = batch.requests().get(i); - assertThat(request.chunkText(), equalTo(inputs.get(i))); + assertThat(request.chunkText(), equalTo(inputs.get(i).input())); assertEquals(i, request.inputIndex()); assertEquals(0, request.chunkIndex()); } @@ -107,10 +122,10 @@ public void testMultipleShortInputsAreSingleBatch() { public void testManyInputsMakeManyBatches() { int maxNumInputsPerBatch = 10; int numInputs = maxNumInputsPerBatch * 3 + 1; // requires 4 batches - var inputs = new ArrayList(); + var inputs = new ArrayList(); for (int i = 0; i < numInputs; i++) { - inputs.add("input " + i); + inputs.add(new ChunkInferenceInput("input " + i)); } var batches = new EmbeddingRequestChunker<>(inputs, maxNumInputsPerBatch, 100, 10).batchRequestsWithListeners(testListener()); @@ -133,7 +148,7 @@ public void testManyInputsMakeManyBatches() { List requests = batches.get(0).batch().requests(); for (int i = 0; i < requests.size(); i++) { EmbeddingRequestChunker.Request request = requests.get(i); - assertThat(request.chunkText(), equalTo(inputs.get(i))); + assertThat(request.chunkText(), equalTo(inputs.get(i).input())); assertThat(request.inputIndex(), equalTo(i)); assertThat(request.chunkIndex(), equalTo(0)); } @@ -142,10 +157,10 @@ public void testManyInputsMakeManyBatches() { public void testChunkingSettingsProvided() { int maxNumInputsPerBatch = 10; int numInputs = maxNumInputsPerBatch * 3 + 1; // requires 4 batches - var inputs = new ArrayList(); + var inputs = new ArrayList(); for (int i = 0; i < numInputs; i++) { - inputs.add("input " + i); + inputs.add(new ChunkInferenceInput("input " + i)); } var batches = new EmbeddingRequestChunker<>(inputs, maxNumInputsPerBatch, ChunkingSettingsTests.createRandomChunkingSettings()) @@ -169,7 +184,7 @@ public void testChunkingSettingsProvided() { List requests = batches.get(0).batch().requests(); for (int i = 0; i < requests.size(); i++) { EmbeddingRequestChunker.Request request = requests.get(i); - assertThat(request.chunkText(), equalTo(inputs.get(i))); + assertThat(request.chunkText(), equalTo(inputs.get(i).input())); assertThat(request.inputIndex(), equalTo(i)); assertThat(request.chunkIndex(), equalTo(0)); } @@ -188,7 +203,12 @@ public void testLongInputChunkedOverMultipleBatches() { passageBuilder.append("passage_input").append(i).append(" "); // chunk on whitespace } - List inputs = List.of("1st small", passageBuilder.toString(), "2nd small", "3rd small"); + List inputs = List.of( + new ChunkInferenceInput("1st small"), + new ChunkInferenceInput(passageBuilder.toString()), + new ChunkInferenceInput("2nd small"), + new ChunkInferenceInput("3rd small") + ); var batches = new EmbeddingRequestChunker<>(inputs, batchSize, chunkSize, overlap).batchRequestsWithListeners(testListener()); @@ -244,7 +264,11 @@ public void testVeryLongInput_Sparse() { passageBuilder.append("word").append(i).append(" "); // chunk on whitespace } - List inputs = List.of("1st small", passageBuilder.toString(), "2nd small"); + List inputs = List.of( + new ChunkInferenceInput("1st small"), + new 
ChunkInferenceInput(passageBuilder.toString()), + new ChunkInferenceInput("2nd small") + ); var finalListener = testListener(); List batches = new EmbeddingRequestChunker<>(inputs, batchSize, chunkSize, 0) @@ -278,7 +302,7 @@ public void testVeryLongInput_Sparse() { assertThat(inference, instanceOf(ChunkedInferenceEmbedding.class)); ChunkedInferenceEmbedding chunkedEmbedding = (ChunkedInferenceEmbedding) inference; assertThat(chunkedEmbedding.chunks(), hasSize(1)); - assertThat(getMatchedText(inputs.get(0), chunkedEmbedding.chunks().get(0).offset()), equalTo("1st small")); + assertThat(getMatchedText(inputs.get(0).input(), chunkedEmbedding.chunks().get(0).offset()), equalTo("1st small")); assertThat(chunkedEmbedding.chunks().get(0).embedding(), instanceOf(SparseEmbeddingResults.Embedding.class)); SparseEmbeddingResults.Embedding embedding = (SparseEmbeddingResults.Embedding) chunkedEmbedding.chunks().get(0).embedding(); assertThat(embedding.tokens(), contains(new WeightedToken("word", 1 / 16384f))); @@ -294,16 +318,19 @@ public void testVeryLongInput_Sparse() { // The first merged chunk consists of 20 small chunks (so 400 words) and the max // weight is the weight of the 20th small chunk (so 21/16384). - assertThat(getMatchedText(inputs.get(1), chunkedEmbedding.chunks().get(0).offset()), startsWith("word0 word1 ")); - assertThat(getMatchedText(inputs.get(1), chunkedEmbedding.chunks().get(0).offset()), endsWith(" word398 word399")); + assertThat(getMatchedText(inputs.get(1).input(), chunkedEmbedding.chunks().get(0).offset()), startsWith("word0 word1 ")); + assertThat(getMatchedText(inputs.get(1).input(), chunkedEmbedding.chunks().get(0).offset()), endsWith(" word398 word399")); assertThat(chunkedEmbedding.chunks().get(0).embedding(), instanceOf(SparseEmbeddingResults.Embedding.class)); embedding = (SparseEmbeddingResults.Embedding) chunkedEmbedding.chunks().get(0).embedding(); assertThat(embedding.tokens(), contains(new WeightedToken("word", 21 / 16384f))); // The last merged chunk consists of 19 small chunks (so 380 words) and the max // weight is the weight of the 10000th small chunk (so 10001/16384). 
- assertThat(getMatchedText(inputs.get(1), chunkedEmbedding.chunks().get(511).offset()), startsWith(" word199620 word199621 ")); - assertThat(getMatchedText(inputs.get(1), chunkedEmbedding.chunks().get(511).offset()), endsWith(" word199998 word199999")); + assertThat( + getMatchedText(inputs.get(1).input(), chunkedEmbedding.chunks().get(511).offset()), + startsWith(" word199620 word199621 ") + ); + assertThat(getMatchedText(inputs.get(1).input(), chunkedEmbedding.chunks().get(511).offset()), endsWith(" word199998 word199999")); assertThat(chunkedEmbedding.chunks().get(511).embedding(), instanceOf(SparseEmbeddingResults.Embedding.class)); embedding = (SparseEmbeddingResults.Embedding) chunkedEmbedding.chunks().get(511).embedding(); assertThat(embedding.tokens(), contains(new WeightedToken("word", 10001 / 16384f))); @@ -313,7 +340,7 @@ public void testVeryLongInput_Sparse() { assertThat(inference, instanceOf(ChunkedInferenceEmbedding.class)); chunkedEmbedding = (ChunkedInferenceEmbedding) inference; assertThat(chunkedEmbedding.chunks(), hasSize(1)); - assertThat(getMatchedText(inputs.get(2), chunkedEmbedding.chunks().get(0).offset()), equalTo("2nd small")); + assertThat(getMatchedText(inputs.get(2).input(), chunkedEmbedding.chunks().get(0).offset()), equalTo("2nd small")); assertThat(chunkedEmbedding.chunks().get(0).embedding(), instanceOf(SparseEmbeddingResults.Embedding.class)); embedding = (SparseEmbeddingResults.Embedding) chunkedEmbedding.chunks().get(0).embedding(); assertThat(embedding.tokens(), contains(new WeightedToken("word", 10002 / 16384f))); @@ -329,7 +356,11 @@ public void testVeryLongInput_Float() { passageBuilder.append("word").append(i).append(" "); // chunk on whitespace } - List inputs = List.of("1st small", passageBuilder.toString(), "2nd small"); + List inputs = List.of( + new ChunkInferenceInput("1st small"), + new ChunkInferenceInput(passageBuilder.toString()), + new ChunkInferenceInput("2nd small") + ); var finalListener = testListener(); List batches = new EmbeddingRequestChunker<>(inputs, batchSize, chunkSize, 0) @@ -362,7 +393,7 @@ public void testVeryLongInput_Float() { assertThat(inference, instanceOf(ChunkedInferenceEmbedding.class)); ChunkedInferenceEmbedding chunkedEmbedding = (ChunkedInferenceEmbedding) inference; assertThat(chunkedEmbedding.chunks(), hasSize(1)); - assertThat(getMatchedText(inputs.get(0), chunkedEmbedding.chunks().get(0).offset()), equalTo("1st small")); + assertThat(getMatchedText(inputs.get(0).input(), chunkedEmbedding.chunks().get(0).offset()), equalTo("1st small")); assertThat(chunkedEmbedding.chunks().get(0).embedding(), instanceOf(TextEmbeddingFloatResults.Embedding.class)); TextEmbeddingFloatResults.Embedding embedding = (TextEmbeddingFloatResults.Embedding) chunkedEmbedding.chunks().get(0).embedding(); assertThat(embedding.values(), equalTo(new float[] { 1 / 16384f })); @@ -378,16 +409,19 @@ public void testVeryLongInput_Float() { // The first merged chunk consists of 20 small chunks (so 400 words) and the weight // is the average of the weights 2/16384 ... 21/16384. 
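A quick note on the expected values in the float assertions that follow: the per-chunk weights form an arithmetic sequence, and the mean of an arithmetic sequence is simply (first + last) / 2. So the first merged chunk, averaging weights 2/16384 through 21/16384, is expected to carry (2 + 21) / (2 * 16384f), and the last merged chunk, averaging 9983/16384 through 10001/16384, carries (9983 + 10001) / (2 * 16384f), exactly as asserted below.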
- assertThat(getMatchedText(inputs.get(1), chunkedEmbedding.chunks().get(0).offset()), startsWith("word0 word1 ")); - assertThat(getMatchedText(inputs.get(1), chunkedEmbedding.chunks().get(0).offset()), endsWith(" word398 word399")); + assertThat(getMatchedText(inputs.get(1).input(), chunkedEmbedding.chunks().get(0).offset()), startsWith("word0 word1 ")); + assertThat(getMatchedText(inputs.get(1).input(), chunkedEmbedding.chunks().get(0).offset()), endsWith(" word398 word399")); assertThat(chunkedEmbedding.chunks().get(0).embedding(), instanceOf(TextEmbeddingFloatResults.Embedding.class)); embedding = (TextEmbeddingFloatResults.Embedding) chunkedEmbedding.chunks().get(0).embedding(); assertThat(embedding.values(), equalTo(new float[] { (2 + 21) / (2 * 16384f) })); // The last merged chunk consists of 19 small chunks (so 380 words) and the weight // is the average of the weights 9983/16384 ... 10001/16384. - assertThat(getMatchedText(inputs.get(1), chunkedEmbedding.chunks().get(511).offset()), startsWith(" word199620 word199621 ")); - assertThat(getMatchedText(inputs.get(1), chunkedEmbedding.chunks().get(511).offset()), endsWith(" word199998 word199999")); + assertThat( + getMatchedText(inputs.get(1).input(), chunkedEmbedding.chunks().get(511).offset()), + startsWith(" word199620 word199621 ") + ); + assertThat(getMatchedText(inputs.get(1).input(), chunkedEmbedding.chunks().get(511).offset()), endsWith(" word199998 word199999")); assertThat(chunkedEmbedding.chunks().get(511).embedding(), instanceOf(TextEmbeddingFloatResults.Embedding.class)); embedding = (TextEmbeddingFloatResults.Embedding) chunkedEmbedding.chunks().get(511).embedding(); assertThat(embedding.values(), equalTo(new float[] { (9983 + 10001) / (2 * 16384f) })); @@ -397,7 +431,7 @@ public void testVeryLongInput_Float() { assertThat(inference, instanceOf(ChunkedInferenceEmbedding.class)); chunkedEmbedding = (ChunkedInferenceEmbedding) inference; assertThat(chunkedEmbedding.chunks(), hasSize(1)); - assertThat(getMatchedText(inputs.get(2), chunkedEmbedding.chunks().get(0).offset()), equalTo("2nd small")); + assertThat(getMatchedText(inputs.get(2).input(), chunkedEmbedding.chunks().get(0).offset()), equalTo("2nd small")); assertThat(chunkedEmbedding.chunks().get(0).embedding(), instanceOf(TextEmbeddingFloatResults.Embedding.class)); embedding = (TextEmbeddingFloatResults.Embedding) chunkedEmbedding.chunks().get(0).embedding(); assertThat(embedding.values(), equalTo(new float[] { 10002 / 16384f })); @@ -413,7 +447,11 @@ public void testVeryLongInput_Byte() { passageBuilder.append("word").append(i).append(" "); // chunk on whitespace } - List inputs = List.of("1st small", passageBuilder.toString(), "2nd small"); + List inputs = List.of( + new ChunkInferenceInput("1st small"), + new ChunkInferenceInput(passageBuilder.toString()), + new ChunkInferenceInput("2nd small") + ); var finalListener = testListener(); List batches = new EmbeddingRequestChunker<>(inputs, batchSize, chunkSize, 0) @@ -446,7 +484,7 @@ public void testVeryLongInput_Byte() { assertThat(inference, instanceOf(ChunkedInferenceEmbedding.class)); ChunkedInferenceEmbedding chunkedEmbedding = (ChunkedInferenceEmbedding) inference; assertThat(chunkedEmbedding.chunks(), hasSize(1)); - assertThat(getMatchedText(inputs.get(0), chunkedEmbedding.chunks().get(0).offset()), equalTo("1st small")); + assertThat(getMatchedText(inputs.get(0).input(), chunkedEmbedding.chunks().get(0).offset()), equalTo("1st small")); assertThat(chunkedEmbedding.chunks().get(0).embedding(), 
instanceOf(TextEmbeddingByteResults.Embedding.class)); TextEmbeddingByteResults.Embedding embedding = (TextEmbeddingByteResults.Embedding) chunkedEmbedding.chunks().get(0).embedding(); assertThat(embedding.values(), equalTo(new byte[] { 1 })); @@ -462,8 +500,8 @@ public void testVeryLongInput_Byte() { // The first merged chunk consists of 20 small chunks (so 400 words) and the weight // is the average of the weights 2 ... 21, so 11.5, which is rounded to 12. - assertThat(getMatchedText(inputs.get(1), chunkedEmbedding.chunks().get(0).offset()), startsWith("word0 word1 ")); - assertThat(getMatchedText(inputs.get(1), chunkedEmbedding.chunks().get(0).offset()), endsWith(" word398 word399")); + assertThat(getMatchedText(inputs.get(1).input(), chunkedEmbedding.chunks().get(0).offset()), startsWith("word0 word1 ")); + assertThat(getMatchedText(inputs.get(1).input(), chunkedEmbedding.chunks().get(0).offset()), endsWith(" word398 word399")); assertThat(chunkedEmbedding.chunks().get(0).embedding(), instanceOf(TextEmbeddingByteResults.Embedding.class)); embedding = (TextEmbeddingByteResults.Embedding) chunkedEmbedding.chunks().get(0).embedding(); assertThat(embedding.values(), equalTo(new byte[] { 12 })); @@ -471,8 +509,11 @@ public void testVeryLongInput_Byte() { // The last merged chunk consists of 19 small chunks (so 380 words) and the weight // is the average of the weights 9983 ... 10001 modulo 256 (bytes overflowing), so // the average of -1, 0, 1, ... , 17, so 8. - assertThat(getMatchedText(inputs.get(1), chunkedEmbedding.chunks().get(511).offset()), startsWith(" word199620 word199621 ")); - assertThat(getMatchedText(inputs.get(1), chunkedEmbedding.chunks().get(511).offset()), endsWith(" word199998 word199999")); + assertThat( + getMatchedText(inputs.get(1).input(), chunkedEmbedding.chunks().get(511).offset()), + startsWith(" word199620 word199621 ") + ); + assertThat(getMatchedText(inputs.get(1).input(), chunkedEmbedding.chunks().get(511).offset()), endsWith(" word199998 word199999")); assertThat(chunkedEmbedding.chunks().get(511).embedding(), instanceOf(TextEmbeddingByteResults.Embedding.class)); embedding = (TextEmbeddingByteResults.Embedding) chunkedEmbedding.chunks().get(511).embedding(); assertThat(embedding.values(), equalTo(new byte[] { 8 })); @@ -482,7 +523,7 @@ public void testVeryLongInput_Byte() { assertThat(inference, instanceOf(ChunkedInferenceEmbedding.class)); chunkedEmbedding = (ChunkedInferenceEmbedding) inference; assertThat(chunkedEmbedding.chunks(), hasSize(1)); - assertThat(getMatchedText(inputs.get(2), chunkedEmbedding.chunks().get(0).offset()), equalTo("2nd small")); + assertThat(getMatchedText(inputs.get(2).input(), chunkedEmbedding.chunks().get(0).offset()), equalTo("2nd small")); assertThat(chunkedEmbedding.chunks().get(0).embedding(), instanceOf(TextEmbeddingByteResults.Embedding.class)); embedding = (TextEmbeddingByteResults.Embedding) chunkedEmbedding.chunks().get(0).embedding(); assertThat(embedding.values(), equalTo(new byte[] { 18 })); @@ -500,7 +541,12 @@ public void testMergingListener_Float() { for (int i = 0; i < numberOfWordsInPassage; i++) { passageBuilder.append("passage_input").append(i).append(" "); // chunk on whitespace } - List inputs = List.of("1st small", passageBuilder.toString(), "2nd small", "3rd small"); + List inputs = List.of( + new ChunkInferenceInput("1st small"), + new ChunkInferenceInput(passageBuilder.toString()), + new ChunkInferenceInput("2nd small"), + new ChunkInferenceInput("3rd small") + ); var finalListener = testListener(); var 
batches = new EmbeddingRequestChunker<>(inputs, batchSize, chunkSize, overlap).batchRequestsWithListeners(finalListener); @@ -529,7 +575,7 @@ public void testMergingListener_Float() { assertThat(chunkedResult, instanceOf(ChunkedInferenceEmbedding.class)); var chunkedFloatResult = (ChunkedInferenceEmbedding) chunkedResult; assertThat(chunkedFloatResult.chunks(), hasSize(1)); - assertThat(getMatchedText(inputs.get(0), chunkedFloatResult.chunks().get(0).offset()), equalTo("1st small")); + assertThat(getMatchedText(inputs.get(0).input(), chunkedFloatResult.chunks().get(0).offset()), equalTo("1st small")); } { // this is the large input split in multiple chunks @@ -537,26 +583,29 @@ public void testMergingListener_Float() { assertThat(chunkedResult, instanceOf(ChunkedInferenceEmbedding.class)); var chunkedFloatResult = (ChunkedInferenceEmbedding) chunkedResult; assertThat(chunkedFloatResult.chunks(), hasSize(6)); - assertThat(getMatchedText(inputs.get(1), chunkedFloatResult.chunks().get(0).offset()), startsWith("passage_input0 ")); - assertThat(getMatchedText(inputs.get(1), chunkedFloatResult.chunks().get(1).offset()), startsWith(" passage_input20 ")); - assertThat(getMatchedText(inputs.get(1), chunkedFloatResult.chunks().get(2).offset()), startsWith(" passage_input40 ")); - assertThat(getMatchedText(inputs.get(1), chunkedFloatResult.chunks().get(3).offset()), startsWith(" passage_input60 ")); - assertThat(getMatchedText(inputs.get(1), chunkedFloatResult.chunks().get(4).offset()), startsWith(" passage_input80 ")); - assertThat(getMatchedText(inputs.get(1), chunkedFloatResult.chunks().get(5).offset()), startsWith(" passage_input100 ")); + assertThat(getMatchedText(inputs.get(1).input(), chunkedFloatResult.chunks().get(0).offset()), startsWith("passage_input0 ")); + assertThat(getMatchedText(inputs.get(1).input(), chunkedFloatResult.chunks().get(1).offset()), startsWith(" passage_input20 ")); + assertThat(getMatchedText(inputs.get(1).input(), chunkedFloatResult.chunks().get(2).offset()), startsWith(" passage_input40 ")); + assertThat(getMatchedText(inputs.get(1).input(), chunkedFloatResult.chunks().get(3).offset()), startsWith(" passage_input60 ")); + assertThat(getMatchedText(inputs.get(1).input(), chunkedFloatResult.chunks().get(4).offset()), startsWith(" passage_input80 ")); + assertThat( + getMatchedText(inputs.get(1).input(), chunkedFloatResult.chunks().get(5).offset()), + startsWith(" passage_input100 ") + ); } { var chunkedResult = finalListener.results.get(2); assertThat(chunkedResult, instanceOf(ChunkedInferenceEmbedding.class)); var chunkedFloatResult = (ChunkedInferenceEmbedding) chunkedResult; assertThat(chunkedFloatResult.chunks(), hasSize(1)); - assertThat(getMatchedText(inputs.get(2), chunkedFloatResult.chunks().get(0).offset()), equalTo("2nd small")); + assertThat(getMatchedText(inputs.get(2).input(), chunkedFloatResult.chunks().get(0).offset()), equalTo("2nd small")); } { var chunkedResult = finalListener.results.get(3); assertThat(chunkedResult, instanceOf(ChunkedInferenceEmbedding.class)); var chunkedFloatResult = (ChunkedInferenceEmbedding) chunkedResult; assertThat(chunkedFloatResult.chunks(), hasSize(1)); - assertThat(getMatchedText(inputs.get(3), chunkedFloatResult.chunks().get(0).offset()), equalTo("3rd small")); + assertThat(getMatchedText(inputs.get(3).input(), chunkedFloatResult.chunks().get(0).offset()), equalTo("3rd small")); } } @@ -572,7 +621,12 @@ public void testMergingListener_Byte() { for (int i = 0; i < numberOfWordsInPassage; i++) { 
passageBuilder.append("passage_input").append(i).append(" "); // chunk on whitespace } - List inputs = List.of("1st small", passageBuilder.toString(), "2nd small", "3rd small"); + List inputs = List.of( + new ChunkInferenceInput("1st small"), + new ChunkInferenceInput(passageBuilder.toString()), + new ChunkInferenceInput("2nd small"), + new ChunkInferenceInput("3rd small") + ); var finalListener = testListener(); var batches = new EmbeddingRequestChunker<>(inputs, batchSize, chunkSize, overlap).batchRequestsWithListeners(finalListener); @@ -601,7 +655,7 @@ public void testMergingListener_Byte() { assertThat(chunkedResult, instanceOf(ChunkedInferenceEmbedding.class)); var chunkedByteResult = (ChunkedInferenceEmbedding) chunkedResult; assertThat(chunkedByteResult.chunks(), hasSize(1)); - assertThat(getMatchedText(inputs.get(0), chunkedByteResult.chunks().get(0).offset()), equalTo("1st small")); + assertThat(getMatchedText(inputs.get(0).input(), chunkedByteResult.chunks().get(0).offset()), equalTo("1st small")); } { // this is the large input split in multiple chunks @@ -609,26 +663,26 @@ public void testMergingListener_Byte() { assertThat(chunkedResult, instanceOf(ChunkedInferenceEmbedding.class)); var chunkedByteResult = (ChunkedInferenceEmbedding) chunkedResult; assertThat(chunkedByteResult.chunks(), hasSize(6)); - assertThat(getMatchedText(inputs.get(1), chunkedByteResult.chunks().get(0).offset()), startsWith("passage_input0 ")); - assertThat(getMatchedText(inputs.get(1), chunkedByteResult.chunks().get(1).offset()), startsWith(" passage_input20 ")); - assertThat(getMatchedText(inputs.get(1), chunkedByteResult.chunks().get(2).offset()), startsWith(" passage_input40 ")); - assertThat(getMatchedText(inputs.get(1), chunkedByteResult.chunks().get(3).offset()), startsWith(" passage_input60 ")); - assertThat(getMatchedText(inputs.get(1), chunkedByteResult.chunks().get(4).offset()), startsWith(" passage_input80 ")); - assertThat(getMatchedText(inputs.get(1), chunkedByteResult.chunks().get(5).offset()), startsWith(" passage_input100 ")); + assertThat(getMatchedText(inputs.get(1).input(), chunkedByteResult.chunks().get(0).offset()), startsWith("passage_input0 ")); + assertThat(getMatchedText(inputs.get(1).input(), chunkedByteResult.chunks().get(1).offset()), startsWith(" passage_input20 ")); + assertThat(getMatchedText(inputs.get(1).input(), chunkedByteResult.chunks().get(2).offset()), startsWith(" passage_input40 ")); + assertThat(getMatchedText(inputs.get(1).input(), chunkedByteResult.chunks().get(3).offset()), startsWith(" passage_input60 ")); + assertThat(getMatchedText(inputs.get(1).input(), chunkedByteResult.chunks().get(4).offset()), startsWith(" passage_input80 ")); + assertThat(getMatchedText(inputs.get(1).input(), chunkedByteResult.chunks().get(5).offset()), startsWith(" passage_input100 ")); } { var chunkedResult = finalListener.results.get(2); assertThat(chunkedResult, instanceOf(ChunkedInferenceEmbedding.class)); var chunkedByteResult = (ChunkedInferenceEmbedding) chunkedResult; assertThat(chunkedByteResult.chunks(), hasSize(1)); - assertThat(getMatchedText(inputs.get(2), chunkedByteResult.chunks().get(0).offset()), equalTo("2nd small")); + assertThat(getMatchedText(inputs.get(2).input(), chunkedByteResult.chunks().get(0).offset()), equalTo("2nd small")); } { var chunkedResult = finalListener.results.get(3); assertThat(chunkedResult, instanceOf(ChunkedInferenceEmbedding.class)); var chunkedByteResult = (ChunkedInferenceEmbedding) chunkedResult; assertThat(chunkedByteResult.chunks(), 
hasSize(1)); - assertThat(getMatchedText(inputs.get(3), chunkedByteResult.chunks().get(0).offset()), equalTo("3rd small")); + assertThat(getMatchedText(inputs.get(3).input(), chunkedByteResult.chunks().get(0).offset()), equalTo("3rd small")); } } @@ -644,7 +698,12 @@ public void testMergingListener_Bit() { for (int i = 0; i < numberOfWordsInPassage; i++) { passageBuilder.append("passage_input").append(i).append(" "); // chunk on whitespace } - List inputs = List.of("1st small", passageBuilder.toString(), "2nd small", "3rd small"); + List inputs = List.of( + new ChunkInferenceInput("1st small"), + new ChunkInferenceInput(passageBuilder.toString()), + new ChunkInferenceInput("2nd small"), + new ChunkInferenceInput("3rd small") + ); var finalListener = testListener(); var batches = new EmbeddingRequestChunker<>(inputs, batchSize, chunkSize, overlap).batchRequestsWithListeners(finalListener); @@ -673,7 +732,7 @@ public void testMergingListener_Bit() { assertThat(chunkedResult, instanceOf(ChunkedInferenceEmbedding.class)); var chunkedByteResult = (ChunkedInferenceEmbedding) chunkedResult; assertThat(chunkedByteResult.chunks(), hasSize(1)); - assertThat(getMatchedText(inputs.get(0), chunkedByteResult.chunks().get(0).offset()), equalTo("1st small")); + assertThat(getMatchedText(inputs.get(0).input(), chunkedByteResult.chunks().get(0).offset()), equalTo("1st small")); } { // this is the large input split in multiple chunks @@ -681,26 +740,26 @@ public void testMergingListener_Bit() { assertThat(chunkedResult, instanceOf(ChunkedInferenceEmbedding.class)); var chunkedByteResult = (ChunkedInferenceEmbedding) chunkedResult; assertThat(chunkedByteResult.chunks(), hasSize(6)); - assertThat(getMatchedText(inputs.get(1), chunkedByteResult.chunks().get(0).offset()), startsWith("passage_input0 ")); - assertThat(getMatchedText(inputs.get(1), chunkedByteResult.chunks().get(1).offset()), startsWith(" passage_input20 ")); - assertThat(getMatchedText(inputs.get(1), chunkedByteResult.chunks().get(2).offset()), startsWith(" passage_input40 ")); - assertThat(getMatchedText(inputs.get(1), chunkedByteResult.chunks().get(3).offset()), startsWith(" passage_input60 ")); - assertThat(getMatchedText(inputs.get(1), chunkedByteResult.chunks().get(4).offset()), startsWith(" passage_input80 ")); - assertThat(getMatchedText(inputs.get(1), chunkedByteResult.chunks().get(5).offset()), startsWith(" passage_input100 ")); + assertThat(getMatchedText(inputs.get(1).input(), chunkedByteResult.chunks().get(0).offset()), startsWith("passage_input0 ")); + assertThat(getMatchedText(inputs.get(1).input(), chunkedByteResult.chunks().get(1).offset()), startsWith(" passage_input20 ")); + assertThat(getMatchedText(inputs.get(1).input(), chunkedByteResult.chunks().get(2).offset()), startsWith(" passage_input40 ")); + assertThat(getMatchedText(inputs.get(1).input(), chunkedByteResult.chunks().get(3).offset()), startsWith(" passage_input60 ")); + assertThat(getMatchedText(inputs.get(1).input(), chunkedByteResult.chunks().get(4).offset()), startsWith(" passage_input80 ")); + assertThat(getMatchedText(inputs.get(1).input(), chunkedByteResult.chunks().get(5).offset()), startsWith(" passage_input100 ")); } { var chunkedResult = finalListener.results.get(2); assertThat(chunkedResult, instanceOf(ChunkedInferenceEmbedding.class)); var chunkedByteResult = (ChunkedInferenceEmbedding) chunkedResult; assertThat(chunkedByteResult.chunks(), hasSize(1)); - assertThat(getMatchedText(inputs.get(2), chunkedByteResult.chunks().get(0).offset()), equalTo("2nd 
small")); + assertThat(getMatchedText(inputs.get(2).input(), chunkedByteResult.chunks().get(0).offset()), equalTo("2nd small")); } { var chunkedResult = finalListener.results.get(3); assertThat(chunkedResult, instanceOf(ChunkedInferenceEmbedding.class)); var chunkedByteResult = (ChunkedInferenceEmbedding) chunkedResult; assertThat(chunkedByteResult.chunks(), hasSize(1)); - assertThat(getMatchedText(inputs.get(3), chunkedByteResult.chunks().get(0).offset()), equalTo("3rd small")); + assertThat(getMatchedText(inputs.get(3).input(), chunkedByteResult.chunks().get(0).offset()), equalTo("3rd small")); } } @@ -716,7 +775,12 @@ public void testMergingListener_Sparse() { for (int i = 0; i < numberOfWordsInPassage; i++) { passageBuilder.append("passage_input").append(i).append(" "); // chunk on whitespace } - List inputs = List.of("1st small", "2nd small", "3rd small", passageBuilder.toString()); + List inputs = List.of( + new ChunkInferenceInput("1st small"), + new ChunkInferenceInput("2nd small"), + new ChunkInferenceInput("3rd small"), + new ChunkInferenceInput(passageBuilder.toString()) + ); var finalListener = testListener(); var batches = new EmbeddingRequestChunker<>(inputs, batchSize, chunkSize, overlap).batchRequestsWithListeners(finalListener); @@ -752,21 +816,21 @@ public void testMergingListener_Sparse() { assertThat(chunkedResult, instanceOf(ChunkedInferenceEmbedding.class)); var chunkedSparseResult = (ChunkedInferenceEmbedding) chunkedResult; assertThat(chunkedSparseResult.chunks(), hasSize(1)); - assertThat(getMatchedText(inputs.get(0), chunkedSparseResult.chunks().get(0).offset()), equalTo("1st small")); + assertThat(getMatchedText(inputs.get(0).input(), chunkedSparseResult.chunks().get(0).offset()), equalTo("1st small")); } { var chunkedResult = finalListener.results.get(1); assertThat(chunkedResult, instanceOf(ChunkedInferenceEmbedding.class)); var chunkedSparseResult = (ChunkedInferenceEmbedding) chunkedResult; assertThat(chunkedSparseResult.chunks(), hasSize(1)); - assertThat(getMatchedText(inputs.get(1), chunkedSparseResult.chunks().get(0).offset()), equalTo("2nd small")); + assertThat(getMatchedText(inputs.get(1).input(), chunkedSparseResult.chunks().get(0).offset()), equalTo("2nd small")); } { var chunkedResult = finalListener.results.get(2); assertThat(chunkedResult, instanceOf(ChunkedInferenceEmbedding.class)); var chunkedSparseResult = (ChunkedInferenceEmbedding) chunkedResult; assertThat(chunkedSparseResult.chunks(), hasSize(1)); - assertThat(getMatchedText(inputs.get(2), chunkedSparseResult.chunks().get(0).offset()), equalTo("3rd small")); + assertThat(getMatchedText(inputs.get(2).input(), chunkedSparseResult.chunks().get(0).offset()), equalTo("3rd small")); } { // this is the large input split in multiple chunks @@ -774,14 +838,24 @@ public void testMergingListener_Sparse() { assertThat(chunkedResult, instanceOf(ChunkedInferenceEmbedding.class)); var chunkedSparseResult = (ChunkedInferenceEmbedding) chunkedResult; assertThat(chunkedSparseResult.chunks(), hasSize(9)); // passage is split into 9 chunks, 10 words each - assertThat(getMatchedText(inputs.get(3), chunkedSparseResult.chunks().get(0).offset()), startsWith("passage_input0 ")); - assertThat(getMatchedText(inputs.get(3), chunkedSparseResult.chunks().get(1).offset()), startsWith(" passage_input10 ")); - assertThat(getMatchedText(inputs.get(3), chunkedSparseResult.chunks().get(8).offset()), startsWith(" passage_input80 ")); + assertThat(getMatchedText(inputs.get(3).input(), 
chunkedSparseResult.chunks().get(0).offset()), startsWith("passage_input0 ")); + assertThat( + getMatchedText(inputs.get(3).input(), chunkedSparseResult.chunks().get(1).offset()), + startsWith(" passage_input10 ") + ); + assertThat( + getMatchedText(inputs.get(3).input(), chunkedSparseResult.chunks().get(8).offset()), + startsWith(" passage_input80 ") + ); } } public void testListenerErrorsWithWrongNumberOfResponses() { - List inputs = List.of("1st small", "2nd small", "3rd small"); + List inputs = List.of( + new ChunkInferenceInput("1st small"), + new ChunkInferenceInput("2nd small"), + new ChunkInferenceInput("3rd small") + ); var failureMessage = new AtomicReference(); var listener = new ActionListener>() { diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/SingleInputSenderExecutableActionTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/SingleInputSenderExecutableActionTests.java index 6f335ab32f01c..0bfa640b0cded 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/SingleInputSenderExecutableActionTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/SingleInputSenderExecutableActionTests.java @@ -10,6 +10,7 @@ import org.elasticsearch.ElasticsearchStatusException; import org.elasticsearch.action.ActionListener; import org.elasticsearch.core.TimeValue; +import org.elasticsearch.inference.ChunkInferenceInput; import org.elasticsearch.inference.InferenceServiceResults; import org.elasticsearch.rest.RestStatus; import org.elasticsearch.test.ESTestCase; @@ -64,7 +65,7 @@ public void testOneInputIsValid() { public void testMoreThanOneInput() { var badInput = mock(EmbeddingsInput.class); - var input = List.of("one", "two"); + var input = List.of(new ChunkInferenceInput("one"), new ChunkInferenceInput("two")); when(badInput.getInputs()).thenReturn(input); when(badInput.isSingleInput()).thenReturn(false); var actualException = new AtomicReference(); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/http/sender/HttpRequestSenderTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/http/sender/HttpRequestSenderTests.java index b106848cb7ed1..77b8f18632126 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/http/sender/HttpRequestSenderTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/http/sender/HttpRequestSenderTests.java @@ -14,6 +14,7 @@ import org.elasticsearch.common.settings.Settings; import org.elasticsearch.common.util.concurrent.ThreadContext; import org.elasticsearch.core.TimeValue; +import org.elasticsearch.inference.ChunkInferenceInput; import org.elasticsearch.inference.InferenceServiceResults; import org.elasticsearch.inference.TaskType; import org.elasticsearch.test.ESTestCase; @@ -119,7 +120,7 @@ public void testCreateSender_SendsRequestAndReceivesResponse() throws Exception PlainActionFuture listener = new PlainActionFuture<>(); sender.send( OpenAiEmbeddingsRequestManagerTests.makeCreator(getUrl(webServer), null, "key", "model", null, threadPool), - new EmbeddingsInput(List.of("abc"), null), + new EmbeddingsInput(List.of(new ChunkInferenceInput("abc")), null), null, listener ); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/http/sender/RequestTaskTests.java 
b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/http/sender/RequestTaskTests.java index a232b7724ca98..d40ee517a1c51 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/http/sender/RequestTaskTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/http/sender/RequestTaskTests.java @@ -11,6 +11,7 @@ import org.elasticsearch.action.ActionListener; import org.elasticsearch.action.support.PlainActionFuture; import org.elasticsearch.core.TimeValue; +import org.elasticsearch.inference.ChunkInferenceInput; import org.elasticsearch.inference.InferenceServiceResults; import org.elasticsearch.test.ESTestCase; import org.elasticsearch.threadpool.Scheduler; @@ -61,7 +62,7 @@ public void testExecuting_DoesNotCallOnFailureForTimeout_AfterIllegalArgumentExc var requestTask = new RequestTask( OpenAiEmbeddingsRequestManagerTests.makeCreator("url", null, "key", "model", null, "id", threadPool), - new EmbeddingsInput(List.of("abc"), InputTypeTests.randomWithNull()), + new EmbeddingsInput(List.of(new ChunkInferenceInput("abc")), InputTypeTests.randomWithNull()), TimeValue.timeValueMillis(1), mockThreadPool, listener @@ -81,7 +82,7 @@ public void testRequest_ReturnsTimeoutException() { PlainActionFuture listener = new PlainActionFuture<>(); var requestTask = new RequestTask( OpenAiEmbeddingsRequestManagerTests.makeCreator("url", null, "key", "model", null, "id", threadPool), - new EmbeddingsInput(List.of("abc"), InputTypeTests.randomWithNull()), + new EmbeddingsInput(List.of(new ChunkInferenceInput("abc")), InputTypeTests.randomWithNull()), TimeValue.timeValueMillis(1), threadPool, listener @@ -105,7 +106,7 @@ public void testRequest_DoesNotCallOnFailureTwiceWhenTimingOut() throws Exceptio var requestTask = new RequestTask( OpenAiEmbeddingsRequestManagerTests.makeCreator("url", null, "key", "model", null, "id", threadPool), - new EmbeddingsInput(List.of("abc"), InputTypeTests.randomWithNull()), + new EmbeddingsInput(List.of(new ChunkInferenceInput("abc")), InputTypeTests.randomWithNull()), TimeValue.timeValueMillis(1), threadPool, listener @@ -134,7 +135,7 @@ public void testRequest_DoesNotCallOnResponseAfterTimingOut() throws Exception { var requestTask = new RequestTask( OpenAiEmbeddingsRequestManagerTests.makeCreator("url", null, "key", "model", null, "id", threadPool), - new EmbeddingsInput(List.of("abc"), InputTypeTests.randomWithNull()), + new EmbeddingsInput(List.of(new ChunkInferenceInput("abc")), InputTypeTests.randomWithNull()), TimeValue.timeValueMillis(1), threadPool, listener @@ -161,7 +162,7 @@ public void testRequest_DoesNotCallOnFailureForTimeout_AfterAlreadyCallingOnResp var requestTask = new RequestTask( OpenAiEmbeddingsRequestManagerTests.makeCreator("url", null, "key", "model", null, "id", threadPool), - new EmbeddingsInput(List.of("abc"), InputTypeTests.randomWithNull()), + new EmbeddingsInput(List.of(new ChunkInferenceInput("abc")), InputTypeTests.randomWithNull()), TimeValue.timeValueMillis(1), mockThreadPool, listener diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/SemanticInferenceMetadataFieldsRecoveryTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/SemanticInferenceMetadataFieldsRecoveryTests.java index d30a4362e19d9..add130da2d368 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/SemanticInferenceMetadataFieldsRecoveryTests.java +++ 
b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/SemanticInferenceMetadataFieldsRecoveryTests.java @@ -27,6 +27,7 @@ import org.elasticsearch.index.mapper.SourceToParse; import org.elasticsearch.index.translog.Translog; import org.elasticsearch.inference.ChunkedInference; +import org.elasticsearch.inference.ChunkingSettings; import org.elasticsearch.inference.Model; import org.elasticsearch.inference.SimilarityMeasure; import org.elasticsearch.inference.TaskType; @@ -42,6 +43,7 @@ import java.util.List; import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertToXContentEquivalent; +import static org.elasticsearch.xpack.inference.mapper.SemanticTextFieldTests.generateRandomChunkingSettings; import static org.elasticsearch.xpack.inference.mapper.SemanticTextFieldTests.randomChunkedInferenceEmbeddingByte; import static org.elasticsearch.xpack.inference.mapper.SemanticTextFieldTests.randomChunkedInferenceEmbeddingFloat; import static org.elasticsearch.xpack.inference.mapper.SemanticTextFieldTests.randomChunkedInferenceEmbeddingSparse; @@ -51,12 +53,14 @@ public class SemanticInferenceMetadataFieldsRecoveryTests extends EngineTestCase { private final Model model1; private final Model model2; + private final ChunkingSettings chunkingSettings; private final boolean useSynthetic; private final boolean useIncludesExcludes; public SemanticInferenceMetadataFieldsRecoveryTests(boolean useSynthetic, boolean useIncludesExcludes) { this.model1 = TestModel.createRandomInstance(TaskType.TEXT_EMBEDDING, List.of(SimilarityMeasure.DOT_PRODUCT)); this.model2 = TestModel.createRandomInstance(TaskType.SPARSE_EMBEDDING); + this.chunkingSettings = generateRandomChunkingSettings(); this.useSynthetic = useSynthetic; this.useIncludesExcludes = useIncludesExcludes; } @@ -105,6 +109,10 @@ protected String defaultMapping() { builder.field("element_type", model1.getServiceSettings().elementType().name()); builder.field("service", model1.getConfigurations().getService()); builder.endObject(); + if (chunkingSettings != null) { + builder.field("chunking_settings"); + chunkingSettings.toXContent(builder, null); + } builder.endObject(); builder.startObject("semantic_2"); @@ -114,6 +122,10 @@ protected String defaultMapping() { builder.field("task_type", model2.getTaskType().name()); builder.field("service", model2.getConfigurations().getService()); builder.endObject(); + if (chunkingSettings != null) { + builder.field("chunking_settings"); + chunkingSettings.toXContent(builder, null); + } builder.endObject(); builder.endObject(); @@ -229,8 +241,8 @@ private BytesReference randomSource() throws IOException { false, builder, List.of( - randomSemanticText(false, "semantic_2", model2, randomInputs(), XContentType.JSON), - randomSemanticText(false, "semantic_1", model1, randomInputs(), XContentType.JSON) + randomSemanticText(false, "semantic_2", model2, chunkingSettings, randomInputs(), XContentType.JSON), + randomSemanticText(false, "semantic_1", model1, chunkingSettings, randomInputs(), XContentType.JSON) ) ); builder.endObject(); @@ -241,6 +253,7 @@ private static SemanticTextField randomSemanticText( boolean useLegacyFormat, String fieldName, Model model, + ChunkingSettings chunkingSettings, List inputs, XContentType contentType ) throws IOException { @@ -252,7 +265,15 @@ private static SemanticTextField randomSemanticText( case SPARSE_EMBEDDING -> randomChunkedInferenceEmbeddingSparse(inputs, false); default -> throw new AssertionError("invalid task type: " + 
model.getTaskType().name()); }; - return semanticTextFieldFromChunkedInferenceResults(useLegacyFormat, fieldName, model, inputs, results, contentType); + return semanticTextFieldFromChunkedInferenceResults( + useLegacyFormat, + fieldName, + model, + chunkingSettings, + inputs, + results, + contentType + ); } private static List randomInputs() { diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapperTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapperTests.java index b7a2d34bedfe2..4d2a76f915af3 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapperTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapperTests.java @@ -54,6 +54,7 @@ import org.elasticsearch.index.mapper.vectors.XFeatureField; import org.elasticsearch.index.query.SearchExecutionContext; import org.elasticsearch.index.search.ESToParentBlockJoinQuery; +import org.elasticsearch.inference.ChunkingSettings; import org.elasticsearch.inference.MinimalServiceSettings; import org.elasticsearch.inference.Model; import org.elasticsearch.inference.SimilarityMeasure; @@ -90,6 +91,8 @@ import static org.elasticsearch.xpack.inference.mapper.SemanticTextField.getChunksFieldName; import static org.elasticsearch.xpack.inference.mapper.SemanticTextField.getEmbeddingsFieldName; import static org.elasticsearch.xpack.inference.mapper.SemanticTextFieldMapper.DEFAULT_ELSER_2_INFERENCE_ID; +import static org.elasticsearch.xpack.inference.mapper.SemanticTextFieldTests.generateRandomChunkingSettings; +import static org.elasticsearch.xpack.inference.mapper.SemanticTextFieldTests.generateRandomChunkingSettingsOtherThan; import static org.elasticsearch.xpack.inference.mapper.SemanticTextFieldTests.randomSemanticText; import static org.hamcrest.Matchers.containsString; import static org.hamcrest.Matchers.equalTo; @@ -178,7 +181,7 @@ protected IngestScriptSupport ingestScriptSupport() { @Override public MappedFieldType getMappedFieldType() { - return new SemanticTextFieldMapper.SemanticTextFieldType("field", "fake-inference-id", null, null, null, false, Map.of()); + return new SemanticTextFieldMapper.SemanticTextFieldType("field", "fake-inference-id", null, null, null, null, false, Map.of()); } @Override @@ -566,6 +569,15 @@ public void testUpdateSearchInferenceId() throws IOException { } private static void assertSemanticTextField(MapperService mapperService, String fieldName, boolean expectedModelSettings) { + assertSemanticTextField(mapperService, fieldName, expectedModelSettings, null); + } + + private static void assertSemanticTextField( + MapperService mapperService, + String fieldName, + boolean expectedModelSettings, + ChunkingSettings expectedChunkingSettings + ) { Mapper mapper = mapperService.mappingLookup().getMapper(fieldName); assertNotNull(mapper); assertThat(mapper, instanceOf(SemanticTextFieldMapper.class)); @@ -617,6 +629,13 @@ private static void assertSemanticTextField(MapperService mapperService, String } else { assertNull(semanticFieldMapper.fieldType().getModelSettings()); } + + if (expectedChunkingSettings != null) { + assertNotNull(semanticFieldMapper.fieldType().getChunkingSettings()); + assertEquals(expectedChunkingSettings, semanticFieldMapper.fieldType().getChunkingSettings()); + } else { + assertNull(semanticFieldMapper.fieldType().getChunkingSettings()); + } } private static void 
assertInferenceEndpoints( @@ -642,9 +661,22 @@ public void testSuccessfulParse() throws IOException { Model model1 = TestModel.createRandomInstance(TaskType.SPARSE_EMBEDDING); Model model2 = TestModel.createRandomInstance(TaskType.SPARSE_EMBEDDING); + ChunkingSettings chunkingSettings = null; // Some chunking settings configs can produce different Lucene docs counts XContentBuilder mapping = mapping(b -> { - addSemanticTextMapping(b, fieldName1, model1.getInferenceEntityId(), setSearchInferenceId ? searchInferenceId : null); - addSemanticTextMapping(b, fieldName2, model2.getInferenceEntityId(), setSearchInferenceId ? searchInferenceId : null); + addSemanticTextMapping( + b, + fieldName1, + model1.getInferenceEntityId(), + setSearchInferenceId ? searchInferenceId : null, + chunkingSettings + ); + addSemanticTextMapping( + b, + fieldName2, + model2.getInferenceEntityId(), + setSearchInferenceId ? searchInferenceId : null, + chunkingSettings + ); }); MapperService mapperService = createMapperService(mapping, useLegacyFormat); @@ -670,8 +702,15 @@ public void testSuccessfulParse() throws IOException { useLegacyFormat, b, List.of( - randomSemanticText(useLegacyFormat, fieldName1, model1, List.of("a b", "c"), XContentType.JSON), - randomSemanticText(useLegacyFormat, fieldName2, model2, List.of("d e f"), XContentType.JSON) + randomSemanticText( + useLegacyFormat, + fieldName1, + model1, + chunkingSettings, + List.of("a b", "c"), + XContentType.JSON + ), + randomSemanticText(useLegacyFormat, fieldName2, model2, chunkingSettings, List.of("d e f"), XContentType.JSON) ) ) ) @@ -752,7 +791,7 @@ public void testSuccessfulParse() throws IOException { public void testMissingInferenceId() throws IOException { final MapperService mapperService = createMapperService( - mapping(b -> addSemanticTextMapping(b, "field", "my_id", null)), + mapping(b -> addSemanticTextMapping(b, "field", "my_id", null, null)), useLegacyFormat ); @@ -778,8 +817,11 @@ public void testMissingInferenceId() throws IOException { assertThat(ex.getCause().getMessage(), containsString("Required [inference_id]")); } - public void testMissingModelSettings() throws IOException { - MapperService mapperService = createMapperService(mapping(b -> addSemanticTextMapping(b, "field", "my_id", null)), useLegacyFormat); + public void testMissingModelSettingsAndChunks() throws IOException { + MapperService mapperService = createMapperService( + mapping(b -> addSemanticTextMapping(b, "field", "my_id", null, null)), + useLegacyFormat + ); IllegalArgumentException ex = expectThrows( DocumentParsingException.class, IllegalArgumentException.class, @@ -791,11 +833,15 @@ public void testMissingModelSettings() throws IOException { ) ) ); - assertThat(ex.getCause().getMessage(), containsString("Required [model_settings, chunks]")); + // Model settings may be null here so we only error on chunks + assertThat(ex.getCause().getMessage(), containsString("Required [chunks]")); } public void testMissingTaskType() throws IOException { - MapperService mapperService = createMapperService(mapping(b -> addSemanticTextMapping(b, "field", "my_id", null)), useLegacyFormat); + MapperService mapperService = createMapperService( + mapping(b -> addSemanticTextMapping(b, "field", "my_id", null, null)), + useLegacyFormat + ); IllegalArgumentException ex = expectThrows( DocumentParsingException.class, IllegalArgumentException.class, @@ -854,10 +900,43 @@ public void testDenseVectorElementType() throws IOException { assertMapperService.accept(byteMapperService, 
DenseVectorFieldMapper.ElementType.BYTE); } + public void testSettingAndUpdatingChunkingSettings() throws IOException { + Model model = TestModel.createRandomInstance(TaskType.SPARSE_EMBEDDING); + final ChunkingSettings chunkingSettings = generateRandomChunkingSettings(false); + String fieldName = "field"; + + SemanticTextField randomSemanticText = randomSemanticText( + useLegacyFormat, + fieldName, + model, + chunkingSettings, + List.of("a"), + XContentType.JSON + ); + + MapperService mapperService = createMapperService( + mapping(b -> addSemanticTextMapping(b, fieldName, model.getInferenceEntityId(), null, chunkingSettings)), + useLegacyFormat + ); + assertSemanticTextField(mapperService, fieldName, false, chunkingSettings); + + ChunkingSettings newChunkingSettings = generateRandomChunkingSettingsOtherThan(chunkingSettings); + merge(mapperService, mapping(b -> addSemanticTextMapping(b, fieldName, model.getInferenceEntityId(), null, newChunkingSettings))); + assertSemanticTextField(mapperService, fieldName, false, newChunkingSettings); + } + public void testModelSettingsRequiredWithChunks() throws IOException { // Create inference results where model settings are set to null and chunks are provided Model model = TestModel.createRandomInstance(TaskType.SPARSE_EMBEDDING); - SemanticTextField randomSemanticText = randomSemanticText(useLegacyFormat, "field", model, List.of("a"), XContentType.JSON); + ChunkingSettings chunkingSettings = generateRandomChunkingSettings(false); + SemanticTextField randomSemanticText = randomSemanticText( + useLegacyFormat, + "field", + model, + chunkingSettings, + List.of("a"), + XContentType.JSON + ); SemanticTextField inferenceResults = new SemanticTextField( randomSemanticText.useLegacyFormat(), randomSemanticText.fieldName(), @@ -865,13 +944,14 @@ public void testModelSettingsRequiredWithChunks() throws IOException { new SemanticTextField.InferenceResult( randomSemanticText.inference().inferenceId(), null, + randomSemanticText.inference().chunkingSettings(), randomSemanticText.inference().chunks() ), randomSemanticText.contentType() ); MapperService mapperService = createMapperService( - mapping(b -> addSemanticTextMapping(b, "field", model.getInferenceEntityId(), null)), + mapping(b -> addSemanticTextMapping(b, "field", model.getInferenceEntityId(), null, chunkingSettings)), useLegacyFormat ); SourceToParse source = source(b -> addSemanticTextInferenceResults(useLegacyFormat, b, List.of(inferenceResults))); @@ -910,7 +990,7 @@ private MapperService mapperServiceForFieldWithModelSettings( useLegacyFormat, fieldName, List.of(), - new SemanticTextField.InferenceResult(inferenceId, modelSettings, Map.of()), + new SemanticTextField.InferenceResult(inferenceId, modelSettings, generateRandomChunkingSettings(), Map.of()), XContentType.JSON ); XContentBuilder builder = JsonXContent.contentBuilder().startObject(); @@ -982,7 +1062,8 @@ private static void addSemanticTextMapping( XContentBuilder mappingBuilder, String fieldName, String inferenceId, - String searchInferenceId + String searchInferenceId, + ChunkingSettings chunkingSettings ) throws IOException { mappingBuilder.startObject(fieldName); mappingBuilder.field("type", SemanticTextFieldMapper.CONTENT_TYPE); @@ -990,6 +1071,11 @@ private static void addSemanticTextMapping( if (searchInferenceId != null) { mappingBuilder.field("search_inference_id", searchInferenceId); } + if (chunkingSettings != null) { + mappingBuilder.startObject("chunking_settings"); + mappingBuilder.mapContents(chunkingSettings.asMap()); + 
mappingBuilder.endObject(); + } mappingBuilder.endObject(); } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldTests.java index 0bde2b275c82d..b4ac5c475d425 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldTests.java @@ -14,6 +14,7 @@ import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper; import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapperTestUtils; import org.elasticsearch.inference.ChunkedInference; +import org.elasticsearch.inference.ChunkingSettings; import org.elasticsearch.inference.MinimalServiceSettings; import org.elasticsearch.inference.Model; import org.elasticsearch.inference.SimilarityMeasure; @@ -29,6 +30,8 @@ import org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatResults; import org.elasticsearch.xpack.core.ml.search.WeightedToken; import org.elasticsearch.xpack.core.utils.FloatConversionUtils; +import org.elasticsearch.xpack.inference.chunking.SentenceBoundaryChunkingSettings; +import org.elasticsearch.xpack.inference.chunking.WordBoundaryChunkingSettings; import org.elasticsearch.xpack.inference.model.TestModel; import java.io.IOException; @@ -71,7 +74,7 @@ protected void assertEqualInstances(SemanticTextField expectedInstance, Semantic assertThat(newInstance.originalValues(), equalTo(expectedInstance.originalValues())); assertThat(newInstance.inference().modelSettings(), equalTo(expectedInstance.inference().modelSettings())); assertThat(newInstance.inference().chunks().size(), equalTo(expectedInstance.inference().chunks().size())); - + assertThat(newInstance.inference().chunkingSettings(), equalTo(expectedInstance.inference().chunkingSettings())); MinimalServiceSettings modelSettings = newInstance.inference().modelSettings(); for (var entry : newInstance.inference().chunks().entrySet()) { var expectedChunks = expectedInstance.inference().chunks().get(entry.getKey()); @@ -119,6 +122,7 @@ protected SemanticTextField createTestInstance() { useLegacyFormat, NAME, TestModel.createRandomInstance(), + generateRandomChunkingSettings(), rawValues, randomFrom(XContentType.values()) ); @@ -248,6 +252,7 @@ public static SemanticTextField randomSemanticText( boolean useLegacyFormat, String fieldName, Model model, + ChunkingSettings chunkingSettings, List inputs, XContentType contentType ) throws IOException { @@ -259,13 +264,22 @@ public static SemanticTextField randomSemanticText( case SPARSE_EMBEDDING -> randomChunkedInferenceEmbeddingSparse(inputs); default -> throw new AssertionError("invalid task type: " + model.getTaskType().name()); }; - return semanticTextFieldFromChunkedInferenceResults(useLegacyFormat, fieldName, model, inputs, results, contentType); + return semanticTextFieldFromChunkedInferenceResults( + useLegacyFormat, + fieldName, + model, + chunkingSettings, + inputs, + results, + contentType + ); } public static SemanticTextField semanticTextFieldFromChunkedInferenceResults( boolean useLegacyFormat, String fieldName, Model model, + ChunkingSettings chunkingSettings, List inputs, ChunkedInference results, XContentType contentType @@ -300,12 +314,30 @@ public static SemanticTextField semanticTextFieldFromChunkedInferenceResults( new SemanticTextField.InferenceResult( 
model.getInferenceEntityId(), new MinimalServiceSettings(model), + chunkingSettings, Map.of(fieldName, chunks) ), contentType ); } + public static ChunkingSettings generateRandomChunkingSettings() { + return generateRandomChunkingSettings(true); + } + + public static ChunkingSettings generateRandomChunkingSettings(boolean allowNull) { + if (allowNull && randomBoolean()) { + return null; // Use model defaults + } + return randomBoolean() + ? new WordBoundaryChunkingSettings(randomIntBetween(20, 100), randomIntBetween(1, 10)) + : new SentenceBoundaryChunkingSettings(randomIntBetween(20, 100), randomIntBetween(0, 1)); + } + + public static ChunkingSettings generateRandomChunkingSettingsOtherThan(ChunkingSettings chunkingSettings) { + return randomValueOtherThan(chunkingSettings, () -> generateRandomChunkingSettings(false)); + } + /** * Returns a randomly generated object for Semantic Text tests purpose. */ diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/queries/SemanticQueryBuilderTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/queries/SemanticQueryBuilderTests.java index 2231951a017f5..c4a6b92ac033c 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/queries/SemanticQueryBuilderTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/queries/SemanticQueryBuilderTests.java @@ -369,7 +369,7 @@ private static SourceToParse buildSemanticTextFieldWithInferenceResults( useLegacyFormat, SEMANTIC_TEXT_FIELD, null, - new SemanticTextField.InferenceResult(INFERENCE_ID, modelSettings, Map.of(SEMANTIC_TEXT_FIELD, List.of())), + new SemanticTextField.InferenceResult(INFERENCE_ID, modelSettings, null, Map.of(SEMANTIC_TEXT_FIELD, List.of())), XContentType.JSON ); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/AlibabaCloudSearchServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/AlibabaCloudSearchServiceTests.java index dfc5994167faf..a3acfbcfee35d 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/AlibabaCloudSearchServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/AlibabaCloudSearchServiceTests.java @@ -16,6 +16,7 @@ import org.elasticsearch.common.settings.Settings; import org.elasticsearch.common.xcontent.XContentHelper; import org.elasticsearch.core.TimeValue; +import org.elasticsearch.inference.ChunkInferenceInput; import org.elasticsearch.inference.ChunkedInference; import org.elasticsearch.inference.ChunkingSettings; import org.elasticsearch.inference.InferenceServiceConfiguration; @@ -514,7 +515,7 @@ public void testChunkedInfer_SparseEmbeddingChunkingSettingsNotSet() throws IOEx } private void testChunkedInfer(TaskType taskType, ChunkingSettings chunkingSettings) throws IOException { - var input = List.of("foo", "bar"); + var input = List.of(new ChunkInferenceInput("foo"), new ChunkInferenceInput("bar")); var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/action/AlibabaCloudSearchCompletionActionTests.java 
b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/action/AlibabaCloudSearchCompletionActionTests.java index 3ac706df819b4..97bac8582c1fc 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/action/AlibabaCloudSearchCompletionActionTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/action/AlibabaCloudSearchCompletionActionTests.java @@ -122,7 +122,7 @@ public void testExecute_ThrowsIllegalArgumentException_WhenInputIsNotChatComplet PlainActionFuture listener = new PlainActionFuture<>(); assertThrows(IllegalArgumentException.class, () -> { action.execute( - new EmbeddingsInput(List.of(randomAlphaOfLength(10)), InputType.INGEST), + new EmbeddingsInput(List.of(randomAlphaOfLength(10)), null, InputType.INGEST), InferenceAction.Request.DEFAULT_TIMEOUT, listener ); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockServiceTests.java index 5b06506515971..8d4ce151605a5 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockServiceTests.java @@ -19,6 +19,7 @@ import org.elasticsearch.common.settings.Settings; import org.elasticsearch.common.xcontent.XContentHelper; import org.elasticsearch.core.TimeValue; +import org.elasticsearch.inference.ChunkInferenceInput; import org.elasticsearch.inference.ChunkedInference; import org.elasticsearch.inference.ChunkingSettings; import org.elasticsearch.inference.InferenceServiceConfiguration; @@ -1501,7 +1502,7 @@ private void testChunkedInfer(AmazonBedrockEmbeddingsModel model) throws IOExcep service.chunkedInfer( model, null, - List.of("a", "bb"), + List.of(new ChunkInferenceInput("a"), new ChunkInferenceInput("bb")), new HashMap<>(), InputType.INTERNAL_INGEST, InferenceAction.Request.DEFAULT_TIMEOUT, diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/amazonbedrock/action/AmazonBedrockActionCreatorTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/amazonbedrock/action/AmazonBedrockActionCreatorTests.java index 145a2e6078360..8b55d5b78f397 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/amazonbedrock/action/AmazonBedrockActionCreatorTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/amazonbedrock/action/AmazonBedrockActionCreatorTests.java @@ -74,7 +74,7 @@ public void testEmbeddingsRequestAction_Titan() throws IOException { var action = creator.create(model, Map.of()); PlainActionFuture listener = new PlainActionFuture<>(); action.execute( - new EmbeddingsInput(List.of("abc"), InputTypeTests.randomWithNull()), + new EmbeddingsInput(List.of("abc"), null, InputTypeTests.randomWithNull()), InferenceAction.Request.DEFAULT_TIMEOUT, listener ); @@ -112,7 +112,7 @@ public void testEmbeddingsRequestAction_Cohere() throws IOException { var action = creator.create(model, Map.of()); PlainActionFuture listener = new PlainActionFuture<>(); var inputType = InputTypeTests.randomWithNull(); - action.execute(new 
EmbeddingsInput(List.of("abc"), inputType), InferenceAction.Request.DEFAULT_TIMEOUT, listener); + action.execute(new EmbeddingsInput(List.of("abc"), null, inputType), InferenceAction.Request.DEFAULT_TIMEOUT, listener); var result = listener.actionGet(TIMEOUT); assertThat(result.asMap(), is(buildExpectationFloat(List.of(new float[] { 0.0123F, -0.0123F })))); @@ -145,7 +145,7 @@ public void testEmbeddingsRequestAction_HandlesException() throws IOException { ); var action = creator.create(model, Map.of()); PlainActionFuture listener = new PlainActionFuture<>(); - action.execute(new EmbeddingsInput(List.of("abc"), null), InferenceAction.Request.DEFAULT_TIMEOUT, listener); + action.execute(new EmbeddingsInput(List.of("abc"), null, null), InferenceAction.Request.DEFAULT_TIMEOUT, listener); var thrownException = expectThrows(ElasticsearchException.class, () -> listener.actionGet(TIMEOUT)); assertThat(sender.sendCount(), is(1)); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/amazonbedrock/client/AmazonBedrockMockRequestSender.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/amazonbedrock/client/AmazonBedrockMockRequestSender.java index 232ed2d23c367..797d50878a0b7 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/amazonbedrock/client/AmazonBedrockMockRequestSender.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/amazonbedrock/client/AmazonBedrockMockRequestSender.java @@ -12,6 +12,7 @@ import org.elasticsearch.action.ActionListener; import org.elasticsearch.cluster.service.ClusterService; import org.elasticsearch.core.TimeValue; +import org.elasticsearch.inference.ChunkInferenceInput; import org.elasticsearch.inference.InferenceServiceResults; import org.elasticsearch.inference.InputType; import org.elasticsearch.xpack.inference.external.http.retry.ResponseHandler; @@ -83,7 +84,7 @@ public void send( ) { sendCounter++; if (inferenceInputs instanceof EmbeddingsInput docsInput) { - inputs.add(docsInput.getInputs()); + inputs.add(ChunkInferenceInput.inputs(docsInput.getInputs())); if (docsInput.getInputType() != null) { inputTypes.add(docsInput.getInputType()); } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/amazonbedrock/client/AmazonBedrockRequestSenderTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/amazonbedrock/client/AmazonBedrockRequestSenderTests.java index c0acc4157bf68..eefa1785715cb 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/amazonbedrock/client/AmazonBedrockRequestSenderTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/amazonbedrock/client/AmazonBedrockRequestSenderTests.java @@ -10,6 +10,7 @@ import org.elasticsearch.action.support.PlainActionFuture; import org.elasticsearch.common.settings.Settings; import org.elasticsearch.core.TimeValue; +import org.elasticsearch.inference.ChunkInferenceInput; import org.elasticsearch.inference.InferenceServiceResults; import org.elasticsearch.test.ESTestCase; import org.elasticsearch.threadpool.ThreadPool; @@ -83,7 +84,7 @@ public void testCreateSender_SendsEmbeddingsRequestAndReceivesResponse() throws threadPool, new TimeValue(30, TimeUnit.SECONDS) ); - sender.send(requestManager, new EmbeddingsInput(List.of("abc"), null), null, listener); + sender.send(requestManager, new 
EmbeddingsInput(List.of(new ChunkInferenceInput("abc")), null), null, listener); var result = listener.actionGet(TIMEOUT); assertThat(result.asMap(), is(buildExpectationFloat(List.of(new float[] { 0.123F, 0.456F, 0.678F, 0.789F })))); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureaistudio/AzureAiStudioServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureaistudio/AzureAiStudioServiceTests.java index 00cfa2a53f8b3..c688803b43ff1 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureaistudio/AzureAiStudioServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureaistudio/AzureAiStudioServiceTests.java @@ -19,6 +19,7 @@ import org.elasticsearch.common.xcontent.XContentHelper; import org.elasticsearch.core.Nullable; import org.elasticsearch.core.TimeValue; +import org.elasticsearch.inference.ChunkInferenceInput; import org.elasticsearch.inference.ChunkedInference; import org.elasticsearch.inference.ChunkingSettings; import org.elasticsearch.inference.InferenceServiceConfiguration; @@ -1235,7 +1236,7 @@ private void testChunkedInfer(AzureAiStudioEmbeddingsModel model) throws IOExcep service.chunkedInfer( model, null, - List.of("a", "bb"), + List.of(new ChunkInferenceInput("a"), new ChunkInferenceInput("bb")), new HashMap<>(), InputType.INGEST, InferenceAction.Request.DEFAULT_TIMEOUT, diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureaistudio/action/AzureAiStudioActionAndCreatorTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureaistudio/action/AzureAiStudioActionAndCreatorTests.java index d4436f449c454..fa27d926ec6f3 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureaistudio/action/AzureAiStudioActionAndCreatorTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureaistudio/action/AzureAiStudioActionAndCreatorTests.java @@ -115,7 +115,7 @@ public void testEmbeddingsRequestAction() throws IOException { var action = creator.create(model, Map.of()); PlainActionFuture listener = new PlainActionFuture<>(); var inputType = InputTypeTests.randomSearchAndIngestWithNull(); - action.execute(new EmbeddingsInput(List.of("abc"), inputType), InferenceAction.Request.DEFAULT_TIMEOUT, listener); + action.execute(new EmbeddingsInput(List.of("abc"), null, inputType), InferenceAction.Request.DEFAULT_TIMEOUT, listener); var result = listener.actionGet(TIMEOUT); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureopenai/AzureOpenAiServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureopenai/AzureOpenAiServiceTests.java index ffda34f0e8fdd..facdb90873d47 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureopenai/AzureOpenAiServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureopenai/AzureOpenAiServiceTests.java @@ -19,6 +19,7 @@ import org.elasticsearch.common.settings.Settings; import org.elasticsearch.common.xcontent.XContentHelper; import org.elasticsearch.core.TimeValue; +import org.elasticsearch.inference.ChunkInferenceInput; import org.elasticsearch.inference.ChunkedInference; import 
org.elasticsearch.inference.ChunkingSettings; import org.elasticsearch.inference.InferenceServiceConfiguration; @@ -1368,7 +1369,7 @@ private void testChunkedInfer(AzureOpenAiEmbeddingsModel model) throws IOExcepti service.chunkedInfer( model, null, - List.of("a", "bb"), + List.of(new ChunkInferenceInput("a"), new ChunkInferenceInput("bb")), new HashMap<>(), InputType.INTERNAL_INGEST, InferenceAction.Request.DEFAULT_TIMEOUT, diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureopenai/action/AzureOpenAiActionCreatorTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureopenai/action/AzureOpenAiActionCreatorTests.java index b7b8477e7d3ba..d91e0be04f137 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureopenai/action/AzureOpenAiActionCreatorTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureopenai/action/AzureOpenAiActionCreatorTests.java @@ -119,7 +119,7 @@ public void testCreate_AzureOpenAiEmbeddingsModel() throws IOException { var action = actionCreator.create(model, overriddenTaskSettings); PlainActionFuture listener = new PlainActionFuture<>(); - action.execute(new EmbeddingsInput(List.of("abc"), inputType), InferenceAction.Request.DEFAULT_TIMEOUT, listener); + action.execute(new EmbeddingsInput(List.of("abc"), null, inputType), InferenceAction.Request.DEFAULT_TIMEOUT, listener); var result = listener.actionGet(TIMEOUT); @@ -170,7 +170,7 @@ public void testCreate_AzureOpenAiEmbeddingsModel_WithoutUser() throws IOExcepti var action = actionCreator.create(model, overriddenTaskSettings); PlainActionFuture listener = new PlainActionFuture<>(); - action.execute(new EmbeddingsInput(List.of("abc"), inputType), InferenceAction.Request.DEFAULT_TIMEOUT, listener); + action.execute(new EmbeddingsInput(List.of("abc"), null, inputType), InferenceAction.Request.DEFAULT_TIMEOUT, listener); var result = listener.actionGet(TIMEOUT); @@ -222,7 +222,7 @@ public void testCreate_AzureOpenAiEmbeddingsModel_FailsFromInvalidResponseFormat var action = actionCreator.create(model, overriddenTaskSettings); PlainActionFuture listener = new PlainActionFuture<>(); - action.execute(new EmbeddingsInput(List.of("abc"), inputType), InferenceAction.Request.DEFAULT_TIMEOUT, listener); + action.execute(new EmbeddingsInput(List.of("abc"), null, inputType), InferenceAction.Request.DEFAULT_TIMEOUT, listener); var failureCauseMessage = "Required [data]"; var thrownException = expectThrows(ElasticsearchStatusException.class, () -> listener.actionGet(TIMEOUT)); @@ -296,7 +296,7 @@ public void testExecute_ReturnsSuccessfulResponse_AfterTruncating_From413StatusC var action = actionCreator.create(model, overriddenTaskSettings); PlainActionFuture listener = new PlainActionFuture<>(); - action.execute(new EmbeddingsInput(List.of("abcd"), inputType), InferenceAction.Request.DEFAULT_TIMEOUT, listener); + action.execute(new EmbeddingsInput(List.of("abcd"), null, inputType), InferenceAction.Request.DEFAULT_TIMEOUT, listener); var result = listener.actionGet(TIMEOUT); @@ -373,7 +373,7 @@ public void testExecute_ReturnsSuccessfulResponse_AfterTruncating_From400StatusC var action = actionCreator.create(model, overriddenTaskSettings); PlainActionFuture listener = new PlainActionFuture<>(); - action.execute(new EmbeddingsInput(List.of("abcd"), inputType), InferenceAction.Request.DEFAULT_TIMEOUT, listener); + action.execute(new 
EmbeddingsInput(List.of("abcd"), null, inputType), InferenceAction.Request.DEFAULT_TIMEOUT, listener); var result = listener.actionGet(TIMEOUT); @@ -433,7 +433,11 @@ public void testExecute_TruncatesInputBeforeSending() throws IOException { var action = actionCreator.create(model, overriddenTaskSettings); PlainActionFuture listener = new PlainActionFuture<>(); - action.execute(new EmbeddingsInput(List.of("super long input"), inputType), InferenceAction.Request.DEFAULT_TIMEOUT, listener); + action.execute( + new EmbeddingsInput(List.of("super long input"), null, inputType), + InferenceAction.Request.DEFAULT_TIMEOUT, + listener + ); var result = listener.actionGet(TIMEOUT); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureopenai/action/AzureOpenAiEmbeddingsActionTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureopenai/action/AzureOpenAiEmbeddingsActionTests.java index 319d8c074fe6a..b888a9e21e161 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureopenai/action/AzureOpenAiEmbeddingsActionTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureopenai/action/AzureOpenAiEmbeddingsActionTests.java @@ -14,6 +14,7 @@ import org.elasticsearch.common.settings.Settings; import org.elasticsearch.core.Nullable; import org.elasticsearch.core.TimeValue; +import org.elasticsearch.inference.ChunkInferenceInput; import org.elasticsearch.inference.InferenceServiceResults; import org.elasticsearch.inference.InputType; import org.elasticsearch.test.ESTestCase; @@ -116,7 +117,11 @@ public void testExecute_ReturnsSuccessfulResponse() throws IOException { PlainActionFuture listener = new PlainActionFuture<>(); var inputType = InputTypeTests.randomWithNull(); - action.execute(new EmbeddingsInput(List.of("abc"), inputType), InferenceAction.Request.DEFAULT_TIMEOUT, listener); + action.execute( + new EmbeddingsInput(List.of(new ChunkInferenceInput("abc")), inputType), + InferenceAction.Request.DEFAULT_TIMEOUT, + listener + ); var result = listener.actionGet(TIMEOUT); @@ -144,7 +149,11 @@ public void testExecute_ThrowsElasticsearchException() { PlainActionFuture listener = new PlainActionFuture<>(); var inputType = InputTypeTests.randomWithNull(); - action.execute(new EmbeddingsInput(List.of("abc"), inputType), InferenceAction.Request.DEFAULT_TIMEOUT, listener); + action.execute( + new EmbeddingsInput(List.of(new ChunkInferenceInput("abc")), inputType), + InferenceAction.Request.DEFAULT_TIMEOUT, + listener + ); var thrownException = expectThrows(ElasticsearchException.class, () -> listener.actionGet(TIMEOUT)); @@ -165,7 +174,11 @@ public void testExecute_ThrowsElasticsearchException_WhenSenderOnFailureIsCalled PlainActionFuture listener = new PlainActionFuture<>(); var inputType = InputTypeTests.randomWithNull(); - action.execute(new EmbeddingsInput(List.of("abc"), inputType), InferenceAction.Request.DEFAULT_TIMEOUT, listener); + action.execute( + new EmbeddingsInput(List.of(new ChunkInferenceInput("abc")), inputType), + InferenceAction.Request.DEFAULT_TIMEOUT, + listener + ); var thrownException = expectThrows(ElasticsearchException.class, () -> listener.actionGet(TIMEOUT)); @@ -186,7 +199,11 @@ public void testExecute_ThrowsElasticsearchException_WhenSenderOnFailureIsCalled PlainActionFuture listener = new PlainActionFuture<>(); var inputType = InputTypeTests.randomWithNull(); - action.execute(new 
EmbeddingsInput(List.of("abc"), inputType), InferenceAction.Request.DEFAULT_TIMEOUT, listener); + action.execute( + new EmbeddingsInput(List.of(new ChunkInferenceInput("abc")), inputType), + InferenceAction.Request.DEFAULT_TIMEOUT, + listener + ); var thrownException = expectThrows(ElasticsearchException.class, () -> listener.actionGet(TIMEOUT)); @@ -201,7 +218,11 @@ public void testExecute_ThrowsException() { PlainActionFuture listener = new PlainActionFuture<>(); var inputType = InputTypeTests.randomWithNull(); - action.execute(new EmbeddingsInput(List.of("abc"), inputType), InferenceAction.Request.DEFAULT_TIMEOUT, listener); + action.execute( + new EmbeddingsInput(List.of(new ChunkInferenceInput("abc")), inputType), + InferenceAction.Request.DEFAULT_TIMEOUT, + listener + ); var thrownException = expectThrows(ElasticsearchException.class, () -> listener.actionGet(TIMEOUT)); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/CohereServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/CohereServiceTests.java index b17a8b29bce26..1d0d921956b73 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/CohereServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/CohereServiceTests.java @@ -20,6 +20,7 @@ import org.elasticsearch.common.xcontent.XContentHelper; import org.elasticsearch.core.TimeValue; import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper; +import org.elasticsearch.inference.ChunkInferenceInput; import org.elasticsearch.inference.ChunkedInference; import org.elasticsearch.inference.ChunkingSettings; import org.elasticsearch.inference.InferenceServiceConfiguration; @@ -1464,7 +1465,7 @@ private void testChunkedInfer(CohereEmbeddingsModel model) throws IOException { service.chunkedInfer( model, null, - List.of("a", "bb"), + List.of(new ChunkInferenceInput("a"), new ChunkInferenceInput("bb")), new HashMap<>(), InputType.UNSPECIFIED, InferenceAction.Request.DEFAULT_TIMEOUT, @@ -1563,7 +1564,7 @@ public void testChunkedInfer_BatchesCalls_Bytes() throws IOException { service.chunkedInfer( model, null, - List.of("a", "bb"), + List.of(new ChunkInferenceInput("a"), new ChunkInferenceInput("bb")), new HashMap<>(), InputType.UNSPECIFIED, InferenceAction.Request.DEFAULT_TIMEOUT, diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/action/CohereActionCreatorTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/action/CohereActionCreatorTests.java index da9e2e872981d..b56a19c0af0f1 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/action/CohereActionCreatorTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/action/CohereActionCreatorTests.java @@ -119,7 +119,7 @@ public void testCreate_CohereEmbeddingsModel() throws IOException { var action = actionCreator.create(model, overriddenTaskSettings); PlainActionFuture listener = new PlainActionFuture<>(); - action.execute(new EmbeddingsInput(List.of("abc"), null), InferenceAction.Request.DEFAULT_TIMEOUT, listener); + action.execute(new EmbeddingsInput(List.of("abc"), null, null), InferenceAction.Request.DEFAULT_TIMEOUT, listener); var result = listener.actionGet(TIMEOUT); diff --git 
a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/action/CohereEmbeddingsActionTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/action/CohereEmbeddingsActionTests.java index 9c1f1e6ef2150..3a7fa5c8ec798 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/action/CohereEmbeddingsActionTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/action/CohereEmbeddingsActionTests.java @@ -127,7 +127,7 @@ public void testExecute_ReturnsSuccessfulResponse() throws IOException { PlainActionFuture listener = new PlainActionFuture<>(); var inputType = InputTypeTests.randomWithNull(); - action.execute(new EmbeddingsInput(List.of("abc"), inputType), InferenceAction.Request.DEFAULT_TIMEOUT, listener); + action.execute(new EmbeddingsInput(List.of("abc"), null, inputType), InferenceAction.Request.DEFAULT_TIMEOUT, listener); var result = listener.actionGet(TIMEOUT); @@ -216,7 +216,7 @@ public void testExecute_ReturnsSuccessfulResponse_ForInt8ResponseType() throws I ); PlainActionFuture listener = new PlainActionFuture<>(); - action.execute(new EmbeddingsInput(List.of("abc"), null), InferenceAction.Request.DEFAULT_TIMEOUT, listener); + action.execute(new EmbeddingsInput(List.of("abc"), null, null), InferenceAction.Request.DEFAULT_TIMEOUT, listener); var result = listener.actionGet(TIMEOUT); @@ -271,7 +271,7 @@ public void testExecute_ThrowsElasticsearchException() { var action = createAction(getUrl(webServer), "secret", CohereEmbeddingsTaskSettings.EMPTY_SETTINGS, null, null, sender); PlainActionFuture listener = new PlainActionFuture<>(); - action.execute(new EmbeddingsInput(List.of("abc"), null), InferenceAction.Request.DEFAULT_TIMEOUT, listener); + action.execute(new EmbeddingsInput(List.of("abc"), null, null), InferenceAction.Request.DEFAULT_TIMEOUT, listener); var thrownException = expectThrows(ElasticsearchException.class, () -> listener.actionGet(TIMEOUT)); @@ -291,7 +291,7 @@ public void testExecute_ThrowsElasticsearchException_WhenSenderOnFailureIsCalled var action = createAction(getUrl(webServer), "secret", CohereEmbeddingsTaskSettings.EMPTY_SETTINGS, null, null, sender); PlainActionFuture listener = new PlainActionFuture<>(); - action.execute(new EmbeddingsInput(List.of("abc"), null), InferenceAction.Request.DEFAULT_TIMEOUT, listener); + action.execute(new EmbeddingsInput(List.of("abc"), null, null), InferenceAction.Request.DEFAULT_TIMEOUT, listener); var thrownException = expectThrows(ElasticsearchException.class, () -> listener.actionGet(TIMEOUT)); @@ -311,7 +311,7 @@ public void testExecute_ThrowsElasticsearchException_WhenSenderOnFailureIsCalled var action = createAction(null, "secret", CohereEmbeddingsTaskSettings.EMPTY_SETTINGS, null, null, sender); PlainActionFuture listener = new PlainActionFuture<>(); - action.execute(new EmbeddingsInput(List.of("abc"), null), InferenceAction.Request.DEFAULT_TIMEOUT, listener); + action.execute(new EmbeddingsInput(List.of("abc"), null, null), InferenceAction.Request.DEFAULT_TIMEOUT, listener); var thrownException = expectThrows(ElasticsearchException.class, () -> listener.actionGet(TIMEOUT)); @@ -325,7 +325,7 @@ public void testExecute_ThrowsException() { var action = createAction(getUrl(webServer), "secret", CohereEmbeddingsTaskSettings.EMPTY_SETTINGS, null, null, sender); PlainActionFuture listener = new PlainActionFuture<>(); - action.execute(new 
EmbeddingsInput(List.of("abc"), null), InferenceAction.Request.DEFAULT_TIMEOUT, listener); + action.execute(new EmbeddingsInput(List.of("abc"), null, null), InferenceAction.Request.DEFAULT_TIMEOUT, listener); var thrownException = expectThrows(ElasticsearchException.class, () -> listener.actionGet(TIMEOUT)); @@ -339,7 +339,7 @@ public void testExecute_ThrowsExceptionWithNullUrl() { var action = createAction(null, "secret", CohereEmbeddingsTaskSettings.EMPTY_SETTINGS, null, null, sender); PlainActionFuture listener = new PlainActionFuture<>(); - action.execute(new EmbeddingsInput(List.of("abc"), null), InferenceAction.Request.DEFAULT_TIMEOUT, listener); + action.execute(new EmbeddingsInput(List.of("abc"), null, null), InferenceAction.Request.DEFAULT_TIMEOUT, listener); var thrownException = expectThrows(ElasticsearchException.class, () -> listener.actionGet(TIMEOUT)); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceTests.java index 4f61269fcc6c2..6a93b1cc19c87 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceTests.java @@ -16,6 +16,7 @@ import org.elasticsearch.common.settings.Settings; import org.elasticsearch.common.xcontent.XContentHelper; import org.elasticsearch.core.TimeValue; +import org.elasticsearch.inference.ChunkInferenceInput; import org.elasticsearch.inference.ChunkedInference; import org.elasticsearch.inference.EmptySecretSettings; import org.elasticsearch.inference.EmptyTaskSettings; @@ -632,7 +633,7 @@ public void testChunkedInfer_PropagatesProductUseCaseHeader() throws IOException service.chunkedInfer( model, null, - List.of("input text"), + List.of(new ChunkInferenceInput("input text")), new HashMap<>(), InputType.INGEST, InferenceAction.Request.DEFAULT_TIMEOUT, @@ -760,7 +761,7 @@ public void testChunkedInfer_PassesThrough() throws IOException { service.chunkedInfer( model, null, - List.of("input text"), + List.of(new ChunkInferenceInput("input text")), new HashMap<>(), InputType.INGEST, InferenceAction.Request.DEFAULT_TIMEOUT, diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/action/ElasticInferenceServiceActionCreatorTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/action/ElasticInferenceServiceActionCreatorTests.java index a2966488742ef..fadf4a899e45d 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/action/ElasticInferenceServiceActionCreatorTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/action/ElasticInferenceServiceActionCreatorTests.java @@ -96,7 +96,7 @@ public void testExecute_ReturnsSuccessfulResponse_ForElserAction() throws IOExce PlainActionFuture listener = new PlainActionFuture<>(); action.execute( - new EmbeddingsInput(List.of("hello world"), InputType.UNSPECIFIED), + new EmbeddingsInput(List.of("hello world"), null, InputType.UNSPECIFIED), InferenceAction.Request.DEFAULT_TIMEOUT, listener ); @@ -157,7 +157,7 @@ public void testSend_FailsFromInvalidResponseFormat_ForElserAction() throws IOEx PlainActionFuture listener = new PlainActionFuture<>(); 
action.execute( - new EmbeddingsInput(List.of("hello world"), InputType.UNSPECIFIED), + new EmbeddingsInput(List.of("hello world"), null, InputType.UNSPECIFIED), InferenceAction.Request.DEFAULT_TIMEOUT, listener ); @@ -214,7 +214,7 @@ public void testExecute_ReturnsSuccessfulResponse_AfterTruncating() throws IOExc PlainActionFuture listener = new PlainActionFuture<>(); action.execute( - new EmbeddingsInput(List.of("hello world"), InputType.UNSPECIFIED), + new EmbeddingsInput(List.of("hello world"), null, InputType.UNSPECIFIED), InferenceAction.Request.DEFAULT_TIMEOUT, listener ); @@ -278,7 +278,7 @@ public void testExecute_TruncatesInputBeforeSending() throws IOException { PlainActionFuture listener = new PlainActionFuture<>(); action.execute( - new EmbeddingsInput(List.of("hello world"), InputType.UNSPECIFIED), + new EmbeddingsInput(List.of("hello world"), null, InputType.UNSPECIFIED), InferenceAction.Request.DEFAULT_TIMEOUT, listener ); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalServiceTests.java index a88cc21953a02..7067577b30189 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalServiceTests.java @@ -25,6 +25,7 @@ import org.elasticsearch.common.xcontent.XContentHelper; import org.elasticsearch.core.TimeValue; import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper; +import org.elasticsearch.inference.ChunkInferenceInput; import org.elasticsearch.inference.ChunkedInference; import org.elasticsearch.inference.ChunkingSettings; import org.elasticsearch.inference.EmptyTaskSettings; @@ -950,7 +951,7 @@ private void testChunkInfer_e5(ChunkingSettings chunkingSettings) throws Interru service.chunkedInfer( model, null, - List.of("a", "bb"), + List.of(new ChunkInferenceInput("a"), new ChunkInferenceInput("bb")), Map.of(), InputType.SEARCH, InferenceAction.Request.DEFAULT_TIMEOUT, @@ -1022,7 +1023,7 @@ private void testChunkInfer_Sparse(ChunkingSettings chunkingSettings) throws Int service.chunkedInfer( model, null, - List.of("a", "bb"), + List.of(new ChunkInferenceInput("a"), new ChunkInferenceInput("bb")), Map.of(), InputType.SEARCH, InferenceAction.Request.DEFAULT_TIMEOUT, @@ -1094,7 +1095,7 @@ private void testChunkInfer_Elser(ChunkingSettings chunkingSettings) throws Inte service.chunkedInfer( model, null, - List.of("a", "bb"), + List.of(new ChunkInferenceInput("a"), new ChunkInferenceInput("bb")), Map.of(), InputType.SEARCH, InferenceAction.Request.DEFAULT_TIMEOUT, @@ -1139,7 +1140,8 @@ public void testChunkInferSetsTokenization() { expectedWindowSize.set(null); service.chunkedInfer( model, - List.of("foo", "bar"), + null, + List.of(new ChunkInferenceInput("foo"), new ChunkInferenceInput("bar")), Map.of(), InputType.SEARCH, InferenceAction.Request.DEFAULT_TIMEOUT, @@ -1150,7 +1152,8 @@ public void testChunkInferSetsTokenization() { expectedWindowSize.set(256); service.chunkedInfer( model, - List.of("foo", "bar"), + null, + List.of(new ChunkInferenceInput("foo"), new ChunkInferenceInput("bar")), Map.of(), InputType.SEARCH, InferenceAction.Request.DEFAULT_TIMEOUT, @@ -1202,7 +1205,7 @@ public void testChunkInfer_FailsBatch() throws 
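// A sketch of the call-site change repeated throughout these hunks (the
// three-argument EmbeddingsInput overload is assumed from this patch, not from
// upstream documentation): the new middle argument carries per-request
// ChunkingSettings, with null meaning "use the model's defaults", so
//
//   new EmbeddingsInput(List.of("hello world"), null, InputType.UNSPECIFIED);
//
// should be equivalent to wrapping each string explicitly:
//
//   new EmbeddingsInput(List.of(new ChunkInferenceInput("hello world", null)), InputType.UNSPECIFIED);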
InterruptedException { service.chunkedInfer( model, null, - List.of("foo", "bar", "baz"), + List.of(new ChunkInferenceInput("foo"), new ChunkInferenceInput("bar"), new ChunkInferenceInput("baz")), Map.of(), InputType.SEARCH, InferenceAction.Request.DEFAULT_TIMEOUT, @@ -1228,7 +1231,7 @@ public void testChunkingLargeDocument() throws InterruptedException { // build a doc with enough words to make numChunks of chunks int wordsPerChunk = 10; int numWords = numChunks * wordsPerChunk; - var input = "word ".repeat(numWords); + var input = new ChunkInferenceInput("word ".repeat(numWords), null); Client client = mock(Client.class); when(client.threadPool()).thenReturn(threadPool); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googleaistudio/GoogleAiStudioServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googleaistudio/GoogleAiStudioServiceTests.java index a1430d36a0f5b..f7e228cb3044c 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googleaistudio/GoogleAiStudioServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googleaistudio/GoogleAiStudioServiceTests.java @@ -19,6 +19,7 @@ import org.elasticsearch.common.xcontent.XContentHelper; import org.elasticsearch.core.Strings; import org.elasticsearch.core.TimeValue; +import org.elasticsearch.inference.ChunkInferenceInput; import org.elasticsearch.inference.ChunkedInference; import org.elasticsearch.inference.ChunkingSettings; import org.elasticsearch.inference.EmptyTaskSettings; @@ -892,7 +893,7 @@ public void testChunkedInfer_ChunkingSettingsSet() throws IOException { private void testChunkedInfer(String modelId, String apiKey, GoogleAiStudioEmbeddingsModel model) throws IOException { - var input = List.of("a", "bb"); + var input = List.of(new ChunkInferenceInput("a"), new ChunkInferenceInput("bb")); var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); @@ -937,7 +938,7 @@ private void testChunkedInfer(String modelId, String apiKey, GoogleAiStudioEmbed assertThat(results.get(0), instanceOf(ChunkedInferenceEmbedding.class)); var floatResult = (ChunkedInferenceEmbedding) results.get(0); assertThat(floatResult.chunks(), hasSize(1)); - assertEquals(new ChunkedInference.TextOffset(0, input.get(0).length()), floatResult.chunks().get(0).offset()); + assertEquals(new ChunkedInference.TextOffset(0, input.get(0).input().length()), floatResult.chunks().get(0).offset()); assertThat(floatResult.chunks().get(0).embedding(), Matchers.instanceOf(TextEmbeddingFloatResults.Embedding.class)); assertTrue( Arrays.equals( @@ -952,7 +953,7 @@ private void testChunkedInfer(String modelId, String apiKey, GoogleAiStudioEmbed assertThat(results.get(1), instanceOf(ChunkedInferenceEmbedding.class)); var floatResult = (ChunkedInferenceEmbedding) results.get(1); assertThat(floatResult.chunks(), hasSize(1)); - assertEquals(new ChunkedInference.TextOffset(0, input.get(1).length()), floatResult.chunks().get(0).offset()); + assertEquals(new ChunkedInference.TextOffset(0, input.get(1).input().length()), floatResult.chunks().get(0).offset()); assertThat(floatResult.chunks().get(0).embedding(), Matchers.instanceOf(TextEmbeddingFloatResults.Embedding.class)); assertTrue( Arrays.equals( @@ -976,7 +977,7 @@ private void testChunkedInfer(String modelId, String apiKey, GoogleAiStudioEmbed "model", Strings.format("%s/%s", "models", modelId), "content", - 
Map.of("parts", List.of(Map.of("text", input.get(0)))), + Map.of("parts", List.of(Map.of("text", input.get(0).input()))), "taskType", "RETRIEVAL_DOCUMENT" ), @@ -984,7 +985,7 @@ private void testChunkedInfer(String modelId, String apiKey, GoogleAiStudioEmbed "model", Strings.format("%s/%s", "models", modelId), "content", - Map.of("parts", List.of(Map.of("text", input.get(1)))), + Map.of("parts", List.of(Map.of("text", input.get(1).input()))), "taskType", "RETRIEVAL_DOCUMENT" ) diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googleaistudio/action/GoogleAiStudioEmbeddingsActionTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googleaistudio/action/GoogleAiStudioEmbeddingsActionTests.java index 5fb2bd2bbc12b..9d514326462d5 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googleaistudio/action/GoogleAiStudioEmbeddingsActionTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googleaistudio/action/GoogleAiStudioEmbeddingsActionTests.java @@ -108,7 +108,7 @@ public void testExecute_ReturnsSuccessfulResponse() throws IOException { PlainActionFuture listener = new PlainActionFuture<>(); var inputType = InputTypeTests.randomWithNull(); - action.execute(new EmbeddingsInput(List.of(input), inputType), InferenceAction.Request.DEFAULT_TIMEOUT, listener); + action.execute(new EmbeddingsInput(List.of(input), null, inputType), InferenceAction.Request.DEFAULT_TIMEOUT, listener); var result = listener.actionGet(TIMEOUT); @@ -163,7 +163,7 @@ public void testExecute_ThrowsElasticsearchException() { PlainActionFuture listener = new PlainActionFuture<>(); action.execute( - new EmbeddingsInput(List.of("abc"), InputTypeTests.randomWithNull()), + new EmbeddingsInput(List.of("abc"), null, InputTypeTests.randomWithNull()), InferenceAction.Request.DEFAULT_TIMEOUT, listener ); @@ -187,7 +187,7 @@ public void testExecute_ThrowsElasticsearchException_WhenSenderOnFailureIsCalled PlainActionFuture listener = new PlainActionFuture<>(); action.execute( - new EmbeddingsInput(List.of("abc"), InputTypeTests.randomWithNull()), + new EmbeddingsInput(List.of("abc"), null, InputTypeTests.randomWithNull()), InferenceAction.Request.DEFAULT_TIMEOUT, listener ); @@ -205,7 +205,7 @@ public void testExecute_ThrowsException() { PlainActionFuture listener = new PlainActionFuture<>(); action.execute( - new EmbeddingsInput(List.of("abc"), InputTypeTests.randomWithNull()), + new EmbeddingsInput(List.of("abc"), null, InputTypeTests.randomWithNull()), InferenceAction.Request.DEFAULT_TIMEOUT, listener ); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googlevertexai/action/GoogleVertexAiEmbeddingsActionTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googlevertexai/action/GoogleVertexAiEmbeddingsActionTests.java index 6a423f12684eb..739262c895ef1 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googlevertexai/action/GoogleVertexAiEmbeddingsActionTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googlevertexai/action/GoogleVertexAiEmbeddingsActionTests.java @@ -75,7 +75,7 @@ public void testExecute_ThrowsElasticsearchException() { PlainActionFuture listener = new PlainActionFuture<>(); action.execute( - new EmbeddingsInput(List.of("abc"), InputTypeTests.randomWithNull()), + new 
EmbeddingsInput(List.of("abc"), null, InputTypeTests.randomWithNull()), InferenceAction.Request.DEFAULT_TIMEOUT, listener ); @@ -99,7 +99,7 @@ public void testExecute_ThrowsElasticsearchException_WhenSenderOnFailureIsCalled PlainActionFuture listener = new PlainActionFuture<>(); action.execute( - new EmbeddingsInput(List.of("abc"), InputTypeTests.randomWithNull()), + new EmbeddingsInput(List.of("abc"), null, InputTypeTests.randomWithNull()), InferenceAction.Request.DEFAULT_TIMEOUT, listener ); @@ -117,7 +117,7 @@ public void testExecute_ThrowsException() { PlainActionFuture listener = new PlainActionFuture<>(); action.execute( - new EmbeddingsInput(List.of("abc"), InputTypeTests.randomWithNull()), + new EmbeddingsInput(List.of("abc"), null, InputTypeTests.randomWithNull()), InferenceAction.Request.DEFAULT_TIMEOUT, listener ); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googlevertexai/action/GoogleVertexAiRerankActionTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googlevertexai/action/GoogleVertexAiRerankActionTests.java index 64c9cc366544c..e9406931f3bda 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googlevertexai/action/GoogleVertexAiRerankActionTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googlevertexai/action/GoogleVertexAiRerankActionTests.java @@ -71,7 +71,7 @@ public void testExecute_ThrowsElasticsearchException() { var action = createAction(getUrl(webServer), "projectId", sender); PlainActionFuture listener = new PlainActionFuture<>(); - action.execute(new EmbeddingsInput(List.of("abc"), null), InferenceAction.Request.DEFAULT_TIMEOUT, listener); + action.execute(new EmbeddingsInput(List.of("abc"), null, null), InferenceAction.Request.DEFAULT_TIMEOUT, listener); var thrownException = expectThrows(ElasticsearchException.class, () -> listener.actionGet(TIMEOUT)); @@ -91,7 +91,7 @@ public void testExecute_ThrowsElasticsearchException_WhenSenderOnFailureIsCalled var action = createAction(getUrl(webServer), "projectId", sender); PlainActionFuture listener = new PlainActionFuture<>(); - action.execute(new EmbeddingsInput(List.of("abc"), null), InferenceAction.Request.DEFAULT_TIMEOUT, listener); + action.execute(new EmbeddingsInput(List.of("abc"), null, null), InferenceAction.Request.DEFAULT_TIMEOUT, listener); var thrownException = expectThrows(ElasticsearchException.class, () -> listener.actionGet(TIMEOUT)); @@ -105,7 +105,7 @@ public void testExecute_ThrowsException() { var action = createAction(getUrl(webServer), "projectId", sender); PlainActionFuture listener = new PlainActionFuture<>(); - action.execute(new EmbeddingsInput(List.of("abc"), null), InferenceAction.Request.DEFAULT_TIMEOUT, listener); + action.execute(new EmbeddingsInput(List.of("abc"), null, null), InferenceAction.Request.DEFAULT_TIMEOUT, listener); var thrownException = expectThrows(ElasticsearchException.class, () -> listener.actionGet(TIMEOUT)); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceElserServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceElserServiceTests.java index 173e68536e22c..3ef0de5f30b0d 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceElserServiceTests.java +++ 
b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceElserServiceTests.java @@ -14,6 +14,7 @@ import org.elasticsearch.common.settings.Settings; import org.elasticsearch.common.xcontent.XContentHelper; import org.elasticsearch.core.TimeValue; +import org.elasticsearch.inference.ChunkInferenceInput; import org.elasticsearch.inference.ChunkedInference; import org.elasticsearch.inference.InferenceServiceConfiguration; import org.elasticsearch.inference.InputType; @@ -96,7 +97,7 @@ public void testChunkedInfer_CallsInfer_Elser_ConvertsFloatResponse() throws IOE service.chunkedInfer( model, null, - List.of("abc"), + List.of(new ChunkInferenceInput("abc")), new HashMap<>(), InputType.INTERNAL_SEARCH, InferenceAction.Request.DEFAULT_TIMEOUT, diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceServiceTests.java index 9575321494923..1b4754f25e59a 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceServiceTests.java @@ -19,6 +19,7 @@ import org.elasticsearch.common.settings.Settings; import org.elasticsearch.common.xcontent.XContentHelper; import org.elasticsearch.core.TimeValue; +import org.elasticsearch.inference.ChunkInferenceInput; import org.elasticsearch.inference.ChunkedInference; import org.elasticsearch.inference.ChunkingSettings; import org.elasticsearch.inference.InferenceServiceConfiguration; @@ -810,7 +811,7 @@ public void testChunkedInfer_CallsInfer_TextEmbedding_ConvertsFloatResponse() th service.chunkedInfer( model, null, - List.of("abc"), + List.of(new ChunkInferenceInput("abc")), new HashMap<>(), InputType.INTERNAL_INGEST, InferenceAction.Request.DEFAULT_TIMEOUT, @@ -862,7 +863,7 @@ public void testChunkedInfer() throws IOException { service.chunkedInfer( model, null, - List.of("abc"), + List.of(new ChunkInferenceInput("abc")), new HashMap<>(), InputType.INTERNAL_INGEST, InferenceAction.Request.DEFAULT_TIMEOUT, diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/action/HuggingFaceActionCreatorTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/action/HuggingFaceActionCreatorTests.java index eabeeff9b7b1a..09619cd076ff5 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/action/HuggingFaceActionCreatorTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/action/HuggingFaceActionCreatorTests.java @@ -95,7 +95,7 @@ public void testExecute_ReturnsSuccessfulResponse_ForElserAction() throws IOExce PlainActionFuture listener = new PlainActionFuture<>(); action.execute( - new EmbeddingsInput(List.of("abc"), InputTypeTests.randomWithNull()), + new EmbeddingsInput(List.of("abc"), null, InputTypeTests.randomWithNull()), InferenceAction.Request.DEFAULT_TIMEOUT, listener ); @@ -166,7 +166,7 @@ public void testSend_FailsFromInvalidResponseFormat_ForElserAction() throws IOEx PlainActionFuture listener = new PlainActionFuture<>(); action.execute( - new EmbeddingsInput(List.of("abc"), InputTypeTests.randomWithNull()), + new EmbeddingsInput(List.of("abc"), 
null, InputTypeTests.randomWithNull()), InferenceAction.Request.DEFAULT_TIMEOUT, listener ); @@ -218,7 +218,7 @@ public void testExecute_ReturnsSuccessfulResponse_ForEmbeddingsAction() throws I PlainActionFuture listener = new PlainActionFuture<>(); action.execute( - new EmbeddingsInput(List.of("abc"), InputTypeTests.randomWithNull()), + new EmbeddingsInput(List.of("abc"), null, InputTypeTests.randomWithNull()), InferenceAction.Request.DEFAULT_TIMEOUT, listener ); @@ -280,7 +280,7 @@ public void testSend_FailsFromInvalidResponseFormat_ForEmbeddingsAction() throws PlainActionFuture listener = new PlainActionFuture<>(); action.execute( - new EmbeddingsInput(List.of("abc"), InputTypeTests.randomWithNull()), + new EmbeddingsInput(List.of("abc"), null, InputTypeTests.randomWithNull()), InferenceAction.Request.DEFAULT_TIMEOUT, listener ); @@ -339,7 +339,7 @@ public void testExecute_ReturnsSuccessfulResponse_AfterTruncating() throws IOExc PlainActionFuture listener = new PlainActionFuture<>(); action.execute( - new EmbeddingsInput(List.of("abcd"), InputTypeTests.randomWithNull()), + new EmbeddingsInput(List.of("abcd"), null, InputTypeTests.randomWithNull()), InferenceAction.Request.DEFAULT_TIMEOUT, listener ); @@ -401,7 +401,7 @@ public void testExecute_TruncatesInputBeforeSending() throws IOException { PlainActionFuture listener = new PlainActionFuture<>(); action.execute( - new EmbeddingsInput(List.of("123456"), InputTypeTests.randomWithNull()), + new EmbeddingsInput(List.of("123456"), null, InputTypeTests.randomWithNull()), InferenceAction.Request.DEFAULT_TIMEOUT, listener ); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/action/HuggingFaceActionTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/action/HuggingFaceActionTests.java index 56b659466d289..d58308807902d 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/action/HuggingFaceActionTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/action/HuggingFaceActionTests.java @@ -63,7 +63,7 @@ public void testExecute_ThrowsElasticsearchException_WhenSenderThrows() { PlainActionFuture listener = new PlainActionFuture<>(); action.execute( - new EmbeddingsInput(List.of("abc"), InputTypeTests.randomWithNull()), + new EmbeddingsInput(List.of("abc"), null, InputTypeTests.randomWithNull()), InferenceAction.Request.DEFAULT_TIMEOUT, listener ); @@ -87,7 +87,7 @@ public void testExecute_ThrowsElasticsearchException_WhenSenderOnFailureIsCalled PlainActionFuture listener = new PlainActionFuture<>(); action.execute( - new EmbeddingsInput(List.of("abc"), InputTypeTests.randomWithNull()), + new EmbeddingsInput(List.of("abc"), null, InputTypeTests.randomWithNull()), InferenceAction.Request.DEFAULT_TIMEOUT, listener ); @@ -108,7 +108,7 @@ public void testExecute_ThrowsException() { PlainActionFuture listener = new PlainActionFuture<>(); action.execute( - new EmbeddingsInput(List.of("abc"), InputTypeTests.randomWithNull()), + new EmbeddingsInput(List.of("abc"), null, InputTypeTests.randomWithNull()), InferenceAction.Request.DEFAULT_TIMEOUT, listener ); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/IbmWatsonxServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/IbmWatsonxServiceTests.java index 17909bef41a3c..544871cef2184 
100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/IbmWatsonxServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/IbmWatsonxServiceTests.java @@ -19,6 +19,7 @@ import org.elasticsearch.common.settings.Settings; import org.elasticsearch.common.xcontent.XContentHelper; import org.elasticsearch.core.TimeValue; +import org.elasticsearch.inference.ChunkInferenceInput; import org.elasticsearch.inference.ChunkedInference; import org.elasticsearch.inference.ChunkingSettings; import org.elasticsearch.inference.EmptyTaskSettings; @@ -731,7 +732,7 @@ public void testChunkedInfer_ChunkingSettingsSet() throws IOException { } private void testChunkedInfer_Batches(ChunkingSettings chunkingSettings) throws IOException { - var input = List.of("a", "bb"); + var input = List.of(new ChunkInferenceInput("a"), new ChunkInferenceInput("bb")); var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); @@ -786,7 +787,7 @@ private void testChunkedInfer_Batches(ChunkingSettings chunkingSettings) throws assertThat(results.get(0), instanceOf(ChunkedInferenceEmbedding.class)); var floatResult = (ChunkedInferenceEmbedding) results.get(0); assertThat(floatResult.chunks(), hasSize(1)); - assertEquals(new ChunkedInference.TextOffset(0, input.get(0).length()), floatResult.chunks().get(0).offset()); + assertEquals(new ChunkedInference.TextOffset(0, input.get(0).input().length()), floatResult.chunks().get(0).offset()); assertThat(floatResult.chunks().get(0).embedding(), Matchers.instanceOf(TextEmbeddingFloatResults.Embedding.class)); assertTrue( Arrays.equals( @@ -801,7 +802,7 @@ private void testChunkedInfer_Batches(ChunkingSettings chunkingSettings) throws assertThat(results.get(1), instanceOf(ChunkedInferenceEmbedding.class)); var floatResult = (ChunkedInferenceEmbedding) results.get(1); assertThat(floatResult.chunks(), hasSize(1)); - assertEquals(new ChunkedInference.TextOffset(0, input.get(1).length()), floatResult.chunks().get(0).offset()); + assertEquals(new ChunkedInference.TextOffset(0, input.get(1).input().length()), floatResult.chunks().get(0).offset()); assertThat(floatResult.chunks().get(0).embedding(), Matchers.instanceOf(TextEmbeddingFloatResults.Embedding.class)); assertTrue( Arrays.equals( diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/action/IbmWatsonxEmbeddingsActionTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/action/IbmWatsonxEmbeddingsActionTests.java index d3a92d5926a23..4eafa696f60e4 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/action/IbmWatsonxEmbeddingsActionTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/action/IbmWatsonxEmbeddingsActionTests.java @@ -114,7 +114,7 @@ public void testExecute_ReturnsSuccessfulResponse() throws IOException { PlainActionFuture listener = new PlainActionFuture<>(); action.execute( - new EmbeddingsInput(List.of(input), InputTypeTests.randomWithNull()), + new EmbeddingsInput(List.of(input), null, InputTypeTests.randomWithNull()), InferenceAction.Request.DEFAULT_TIMEOUT, listener ); @@ -144,7 +144,7 @@ public void testExecute_ThrowsElasticsearchException() { PlainActionFuture listener = new PlainActionFuture<>(); action.execute( - new EmbeddingsInput(List.of("abc"), 
InputTypeTests.randomWithNull()), + new EmbeddingsInput(List.of("abc"), null, InputTypeTests.randomWithNull()), InferenceAction.Request.DEFAULT_TIMEOUT, listener ); @@ -173,7 +173,7 @@ public void testExecute_ThrowsElasticsearchException_WhenSenderOnFailureIsCalled PlainActionFuture listener = new PlainActionFuture<>(); action.execute( - new EmbeddingsInput(List.of("abc"), InputTypeTests.randomWithNull()), + new EmbeddingsInput(List.of("abc"), null, InputTypeTests.randomWithNull()), InferenceAction.Request.DEFAULT_TIMEOUT, listener ); @@ -197,7 +197,7 @@ public void testExecute_ThrowsException() { PlainActionFuture listener = new PlainActionFuture<>(); action.execute( - new EmbeddingsInput(List.of("abc"), InputTypeTests.randomWithNull()), + new EmbeddingsInput(List.of("abc"), null, InputTypeTests.randomWithNull()), InferenceAction.Request.DEFAULT_TIMEOUT, listener ); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/jinaai/JinaAIServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/jinaai/JinaAIServiceTests.java index e1446c36b893e..47937629fca3c 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/jinaai/JinaAIServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/jinaai/JinaAIServiceTests.java @@ -19,6 +19,7 @@ import org.elasticsearch.common.settings.Settings; import org.elasticsearch.common.xcontent.XContentHelper; import org.elasticsearch.core.TimeValue; +import org.elasticsearch.inference.ChunkInferenceInput; import org.elasticsearch.inference.ChunkedInference; import org.elasticsearch.inference.ChunkingSettings; import org.elasticsearch.inference.InferenceServiceConfiguration; @@ -1852,7 +1853,7 @@ private void test_Embedding_ChunkedInfer_BatchesCalls(JinaAIEmbeddingsModel mode service.chunkedInfer( model, null, - List.of("a", "bb"), + List.of(new ChunkInferenceInput("a"), new ChunkInferenceInput("bb")), new HashMap<>(), InputType.UNSPECIFIED, InferenceAction.Request.DEFAULT_TIMEOUT, diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/mistral/MistralServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/mistral/MistralServiceTests.java index db771f13cc0a6..cca1627e767c4 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/mistral/MistralServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/mistral/MistralServiceTests.java @@ -19,6 +19,7 @@ import org.elasticsearch.common.xcontent.XContentHelper; import org.elasticsearch.core.Nullable; import org.elasticsearch.core.TimeValue; +import org.elasticsearch.inference.ChunkInferenceInput; import org.elasticsearch.inference.ChunkedInference; import org.elasticsearch.inference.ChunkingSettings; import org.elasticsearch.inference.InferenceServiceConfiguration; @@ -713,7 +714,7 @@ public void testChunkedInfer(MistralEmbeddingsModel model) throws IOException { service.chunkedInfer( model, null, - List.of("abc", "def"), + List.of(new ChunkInferenceInput("abc"), new ChunkInferenceInput("def")), new HashMap<>(), InputType.INTERNAL_INGEST, InferenceAction.Request.DEFAULT_TIMEOUT, diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/OpenAiServiceTests.java 
b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/OpenAiServiceTests.java index 2e85077968090..3229a750dee42 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/OpenAiServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/OpenAiServiceTests.java @@ -21,6 +21,7 @@ import org.elasticsearch.common.settings.Settings; import org.elasticsearch.common.xcontent.XContentHelper; import org.elasticsearch.core.TimeValue; +import org.elasticsearch.inference.ChunkInferenceInput; import org.elasticsearch.inference.ChunkedInference; import org.elasticsearch.inference.ChunkingSettings; import org.elasticsearch.inference.InferenceServiceConfiguration; @@ -1909,7 +1910,7 @@ private void testChunkedInfer(OpenAiEmbeddingsModel model) throws IOException { service.chunkedInfer( model, null, - List.of("a", "bb"), + List.of(new ChunkInferenceInput("a"), new ChunkInferenceInput("bb")), new HashMap<>(), InputType.INTERNAL_INGEST, InferenceAction.Request.DEFAULT_TIMEOUT, diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/action/OpenAiActionCreatorTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/action/OpenAiActionCreatorTests.java index 102cdbec77d74..d7c72cf98e267 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/action/OpenAiActionCreatorTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/action/OpenAiActionCreatorTests.java @@ -109,7 +109,7 @@ public void testCreate_OpenAiEmbeddingsModel() throws IOException { PlainActionFuture listener = new PlainActionFuture<>(); action.execute( - new EmbeddingsInput(List.of("abc"), InputTypeTests.randomWithNull()), + new EmbeddingsInput(List.of("abc"), null, InputTypeTests.randomWithNull()), InferenceAction.Request.DEFAULT_TIMEOUT, listener ); @@ -166,7 +166,7 @@ public void testCreate_OpenAiEmbeddingsModel_WithoutUser() throws IOException { PlainActionFuture listener = new PlainActionFuture<>(); action.execute( - new EmbeddingsInput(List.of("abc"), InputTypeTests.randomWithNull()), + new EmbeddingsInput(List.of("abc"), null, InputTypeTests.randomWithNull()), InferenceAction.Request.DEFAULT_TIMEOUT, listener ); @@ -222,7 +222,7 @@ public void testCreate_OpenAiEmbeddingsModel_WithoutOrganization() throws IOExce PlainActionFuture listener = new PlainActionFuture<>(); action.execute( - new EmbeddingsInput(List.of("abc"), InputTypeTests.randomWithNull()), + new EmbeddingsInput(List.of("abc"), null, InputTypeTests.randomWithNull()), InferenceAction.Request.DEFAULT_TIMEOUT, listener ); @@ -285,7 +285,7 @@ public void testCreate_OpenAiEmbeddingsModel_FailsFromInvalidResponseFormat() th PlainActionFuture listener = new PlainActionFuture<>(); action.execute( - new EmbeddingsInput(List.of("abc"), InputTypeTests.randomWithNull()), + new EmbeddingsInput(List.of("abc"), null, InputTypeTests.randomWithNull()), InferenceAction.Request.DEFAULT_TIMEOUT, listener ); @@ -625,7 +625,7 @@ public void testExecute_ReturnsSuccessfulResponse_AfterTruncating_From413StatusC PlainActionFuture listener = new PlainActionFuture<>(); action.execute( - new EmbeddingsInput(List.of("abcd"), InputTypeTests.randomWithNull()), + new EmbeddingsInput(List.of("abcd"), null, InputTypeTests.randomWithNull()), InferenceAction.Request.DEFAULT_TIMEOUT, listener ); @@ -712,7 
+712,7 @@ public void testExecute_ReturnsSuccessfulResponse_AfterTruncating_From400StatusC PlainActionFuture listener = new PlainActionFuture<>(); action.execute( - new EmbeddingsInput(List.of("abcd"), InputTypeTests.randomWithNull()), + new EmbeddingsInput(List.of("abcd"), null, InputTypeTests.randomWithNull()), InferenceAction.Request.DEFAULT_TIMEOUT, listener ); @@ -784,7 +784,7 @@ public void testExecute_TruncatesInputBeforeSending() throws IOException { PlainActionFuture listener = new PlainActionFuture<>(); action.execute( - new EmbeddingsInput(List.of("super long input"), InputTypeTests.randomWithNull()), + new EmbeddingsInput(List.of("super long input"), null, InputTypeTests.randomWithNull()), InferenceAction.Request.DEFAULT_TIMEOUT, listener ); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/action/OpenAiEmbeddingsActionTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/action/OpenAiEmbeddingsActionTests.java index 4242c930356a5..acc98724436f9 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/action/OpenAiEmbeddingsActionTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/action/OpenAiEmbeddingsActionTests.java @@ -14,6 +14,7 @@ import org.elasticsearch.common.settings.Settings; import org.elasticsearch.core.Nullable; import org.elasticsearch.core.TimeValue; +import org.elasticsearch.inference.ChunkInferenceInput; import org.elasticsearch.inference.InferenceServiceResults; import org.elasticsearch.test.ESTestCase; import org.elasticsearch.test.http.MockResponse; @@ -114,7 +115,7 @@ public void testExecute_ReturnsSuccessfulResponse() throws IOException { PlainActionFuture listener = new PlainActionFuture<>(); action.execute( - new EmbeddingsInput(List.of("abc"), InputTypeTests.randomWithNull()), + new EmbeddingsInput(List.of("abc"), null, InputTypeTests.randomWithNull()), InferenceAction.Request.DEFAULT_TIMEOUT, listener ); @@ -154,7 +155,7 @@ public void testExecute_ThrowsElasticsearchException() { PlainActionFuture listener = new PlainActionFuture<>(); action.execute( - new EmbeddingsInput(List.of("abc"), InputTypeTests.randomWithNull()), + new EmbeddingsInput(List.of("abc"), null, InputTypeTests.randomWithNull()), InferenceAction.Request.DEFAULT_TIMEOUT, listener ); @@ -178,7 +179,7 @@ public void testExecute_ThrowsElasticsearchException_WhenSenderOnFailureIsCalled PlainActionFuture listener = new PlainActionFuture<>(); action.execute( - new EmbeddingsInput(List.of("abc"), InputTypeTests.randomWithNull()), + new EmbeddingsInput(List.of(new ChunkInferenceInput("abc")), InputTypeTests.randomWithNull()), InferenceAction.Request.DEFAULT_TIMEOUT, listener ); @@ -202,7 +203,7 @@ public void testExecute_ThrowsElasticsearchException_WhenSenderOnFailureIsCalled PlainActionFuture listener = new PlainActionFuture<>(); action.execute( - new EmbeddingsInput(List.of("abc"), InputTypeTests.randomWithNull()), + new EmbeddingsInput(List.of(new ChunkInferenceInput("abc")), InputTypeTests.randomWithNull()), InferenceAction.Request.DEFAULT_TIMEOUT, listener ); @@ -220,7 +221,7 @@ public void testExecute_ThrowsException() { PlainActionFuture listener = new PlainActionFuture<>(); action.execute( - new EmbeddingsInput(List.of("abc"), InputTypeTests.randomWithNull()), + new EmbeddingsInput(List.of(new ChunkInferenceInput("abc")), InputTypeTests.randomWithNull()), 
InferenceAction.Request.DEFAULT_TIMEOUT, listener ); @@ -238,7 +239,7 @@ public void testExecute_ThrowsExceptionWithNullUrl() { PlainActionFuture listener = new PlainActionFuture<>(); action.execute( - new EmbeddingsInput(List.of("abc"), InputTypeTests.randomWithNull()), + new EmbeddingsInput(List.of(new ChunkInferenceInput("abc")), InputTypeTests.randomWithNull()), InferenceAction.Request.DEFAULT_TIMEOUT, listener ); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/voyageai/VoyageAIServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/voyageai/VoyageAIServiceTests.java index 521d042bb8615..7155858a1ac03 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/voyageai/VoyageAIServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/voyageai/VoyageAIServiceTests.java @@ -18,6 +18,7 @@ import org.elasticsearch.common.settings.Settings; import org.elasticsearch.common.xcontent.XContentHelper; import org.elasticsearch.core.TimeValue; +import org.elasticsearch.inference.ChunkInferenceInput; import org.elasticsearch.inference.ChunkedInference; import org.elasticsearch.inference.ChunkingSettings; import org.elasticsearch.inference.InferenceServiceConfiguration; @@ -1804,7 +1805,7 @@ private void test_Embedding_ChunkedInfer_BatchesCalls(VoyageAIEmbeddingsModel mo service.chunkedInfer( model, null, - List.of("a", "bb"), + List.of(new ChunkInferenceInput("a"), new ChunkInferenceInput("bb")), new HashMap<>(), InputType.UNSPECIFIED, InferenceAction.Request.DEFAULT_TIMEOUT, diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/voyageai/action/VoyageAIActionCreatorTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/voyageai/action/VoyageAIActionCreatorTests.java index 90a3a35e35e81..10ed276c82a9b 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/voyageai/action/VoyageAIActionCreatorTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/voyageai/action/VoyageAIActionCreatorTests.java @@ -11,6 +11,7 @@ import org.elasticsearch.action.support.PlainActionFuture; import org.elasticsearch.common.settings.Settings; import org.elasticsearch.core.TimeValue; +import org.elasticsearch.inference.ChunkInferenceInput; import org.elasticsearch.inference.InferenceServiceResults; import org.elasticsearch.inference.InputType; import org.elasticsearch.test.ESTestCase; @@ -110,7 +111,11 @@ public void testCreate_VoyageAIEmbeddingsModel() throws IOException { PlainActionFuture listener = new PlainActionFuture<>(); var inputType = InputTypeTests.randomSearchAndIngestWithNull(); - action.execute(new EmbeddingsInput(List.of("abc"), inputType), InferenceAction.Request.DEFAULT_TIMEOUT, listener); + action.execute( + new EmbeddingsInput(List.of(new ChunkInferenceInput("abc")), inputType), + InferenceAction.Request.DEFAULT_TIMEOUT, + listener + ); var result = listener.actionGet(TIMEOUT); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/voyageai/action/VoyageAIEmbeddingsActionTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/voyageai/action/VoyageAIEmbeddingsActionTests.java index e4ef60fc4ff32..47b6ac0c9dbd5 100644 --- 
a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/voyageai/action/VoyageAIEmbeddingsActionTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/voyageai/action/VoyageAIEmbeddingsActionTests.java @@ -14,6 +14,7 @@ import org.elasticsearch.common.settings.Settings; import org.elasticsearch.core.Nullable; import org.elasticsearch.core.TimeValue; +import org.elasticsearch.inference.ChunkInferenceInput; import org.elasticsearch.inference.InferenceServiceResults; import org.elasticsearch.inference.InputType; import org.elasticsearch.test.ESTestCase; @@ -120,7 +121,11 @@ public void testExecute_ReturnsSuccessfulResponse() throws IOException { PlainActionFuture listener = new PlainActionFuture<>(); var inputType = InputTypeTests.randomSearchAndIngestWithNull(); - action.execute(new EmbeddingsInput(List.of("abc"), inputType), InferenceAction.Request.DEFAULT_TIMEOUT, listener); + action.execute( + new EmbeddingsInput(List.of(new ChunkInferenceInput("abc")), inputType), + InferenceAction.Request.DEFAULT_TIMEOUT, + listener + ); var result = listener.actionGet(TIMEOUT); @@ -217,7 +222,11 @@ public void testExecute_ReturnsSuccessfulResponse_ForInt8ResponseType() throws I PlainActionFuture listener = new PlainActionFuture<>(); var inputType = InputTypeTests.randomSearchAndIngestWithNull(); - action.execute(new EmbeddingsInput(List.of("abc"), inputType), InferenceAction.Request.DEFAULT_TIMEOUT, listener); + action.execute( + new EmbeddingsInput(List.of(new ChunkInferenceInput("abc")), inputType), + InferenceAction.Request.DEFAULT_TIMEOUT, + listener + ); var result = listener.actionGet(TIMEOUT); @@ -314,7 +323,11 @@ public void testExecute_ReturnsSuccessfulResponse_ForBinaryResponseType() throws PlainActionFuture listener = new PlainActionFuture<>(); var inputType = InputTypeTests.randomSearchAndIngestWithNull(); - action.execute(new EmbeddingsInput(List.of("abc"), inputType), InferenceAction.Request.DEFAULT_TIMEOUT, listener); + action.execute( + new EmbeddingsInput(List.of(new ChunkInferenceInput("abc")), inputType), + InferenceAction.Request.DEFAULT_TIMEOUT, + listener + ); var result = listener.actionGet(TIMEOUT); @@ -383,7 +396,7 @@ public void testExecute_ThrowsElasticsearchException() { PlainActionFuture listener = new PlainActionFuture<>(); action.execute( - new EmbeddingsInput(List.of("abc"), InputTypeTests.randomSearchAndIngestWithNull()), + new EmbeddingsInput(List.of(new ChunkInferenceInput("abc")), InputTypeTests.randomSearchAndIngestWithNull()), InferenceAction.Request.DEFAULT_TIMEOUT, listener ); @@ -407,7 +420,7 @@ public void testExecute_ThrowsElasticsearchException_WhenSenderOnFailureIsCalled PlainActionFuture listener = new PlainActionFuture<>(); action.execute( - new EmbeddingsInput(List.of("abc"), InputTypeTests.randomSearchAndIngestWithNull()), + new EmbeddingsInput(List.of(new ChunkInferenceInput("abc")), InputTypeTests.randomSearchAndIngestWithNull()), InferenceAction.Request.DEFAULT_TIMEOUT, listener ); @@ -425,7 +438,7 @@ public void testExecute_ThrowsException() { PlainActionFuture listener = new PlainActionFuture<>(); action.execute( - new EmbeddingsInput(List.of("abc"), InputTypeTests.randomSearchAndIngestWithNull()), + new EmbeddingsInput(List.of(new ChunkInferenceInput("abc")), InputTypeTests.randomSearchAndIngestWithNull()), InferenceAction.Request.DEFAULT_TIMEOUT, listener ); @@ -448,7 +461,7 @@ private ExecutableAction createAction( threadPool, model, EMBEDDINGS_HANDLER, - (embeddingsInput) -> 
new VoyageAIEmbeddingsRequest(embeddingsInput.getInputs(), embeddingsInput.getInputType(), model), + (embeddingsInput) -> new VoyageAIEmbeddingsRequest(embeddingsInput.getStringInputs(), embeddingsInput.getInputType(), model), EmbeddingsInput.class ); diff --git a/x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/25_semantic_text_field_mapping_chunking.yml b/x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/25_semantic_text_field_mapping_chunking.yml new file mode 100644 index 0000000000000..a6ff307f0ef4a --- /dev/null +++ b/x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/25_semantic_text_field_mapping_chunking.yml @@ -0,0 +1,523 @@ +setup: + - requires: + cluster_features: "semantic_text.support_chunking_config" + reason: semantic_text chunking configuration added in 8.19 + + - do: + inference.put: + task_type: text_embedding + inference_id: dense-inference-id + body: > + { + "service": "text_embedding_test_service", + "service_settings": { + "model": "my_model", + "dimensions": 10, + "similarity": "cosine", + "api_key": "abc64" + }, + "task_settings": { + } + } + + - do: + inference.put: + task_type: sparse_embedding + inference_id: sparse-inference-id + body: > + { + "service": "test_service", + "service_settings": { + "model": "my_model", + "api_key": "abc64" + }, + "task_settings": { + } + } + + - do: + indices.create: + index: default-chunking-sparse + body: + mappings: + properties: + keyword_field: + type: keyword + inference_field: + type: semantic_text + inference_id: sparse-inference-id + + - do: + indices.create: + index: default-chunking-dense + body: + mappings: + properties: + keyword_field: + type: keyword + inference_field: + type: semantic_text + inference_id: dense-inference-id + + - do: + indices.create: + index: custom-chunking-sparse + body: + mappings: + properties: + keyword_field: + type: keyword + inference_field: + type: semantic_text + inference_id: sparse-inference-id + chunking_settings: + strategy: word + max_chunk_size: 10 + overlap: 1 + + - do: + indices.create: + index: custom-chunking-dense + body: + mappings: + properties: + keyword_field: + type: keyword + inference_field: + type: semantic_text + inference_id: dense-inference-id + chunking_settings: + strategy: word + max_chunk_size: 10 + overlap: 1 + + - do: + index: + index: default-chunking-sparse + id: doc_1 + body: + keyword_field: "default sentence chunking" + inference_field: "Elasticsearch is an open source, distributed, RESTful, search engine which is built on top of Lucene internally and enjoys all the features it provides." + refresh: true + + - do: + index: + index: custom-chunking-sparse + id: doc_2 + body: + keyword_field: "custom word chunking" + inference_field: "Elasticsearch is an open source, distributed, RESTful, search engine which is built on top of Lucene internally and enjoys all the features it provides." + refresh: true + + - do: + index: + index: default-chunking-dense + id: doc_3 + body: + keyword_field: "default sentence chunking" + inference_field: "Elasticsearch is an open source, distributed, RESTful, search engine which is built on top of Lucene internally and enjoys all the features it provides." 
+ refresh: true + + - do: + index: + index: custom-chunking-dense + id: doc_4 + body: + keyword_field: "custom word chunking" + inference_field: "Elasticsearch is an open source, distributed, RESTful, search engine which is built on top of Lucene internally and enjoys all the features it provides." + refresh: true + +--- +"We return chunking configurations with mappings": + + - do: + indices.get_mapping: + index: default-chunking-sparse + + - not_exists: default-chunking-sparse.mappings.properties.inference_field.chunking_settings + + - do: + indices.get_mapping: + index: custom-chunking-sparse + + - match: { "custom-chunking-sparse.mappings.properties.inference_field.chunking_settings.strategy": "word" } + - match: { "custom-chunking-sparse.mappings.properties.inference_field.chunking_settings.max_chunk_size": 10 } + - match: { "custom-chunking-sparse.mappings.properties.inference_field.chunking_settings.overlap": 1 } + + - do: + indices.get_mapping: + index: default-chunking-dense + + - not_exists: default-chunking-dense.mappings.properties.inference_field.chunking_settings + + - do: + indices.get_mapping: + index: custom-chunking-dense + + - match: { "custom-chunking-dense.mappings.properties.inference_field.chunking_settings.strategy": "word" } + - match: { "custom-chunking-dense.mappings.properties.inference_field.chunking_settings.max_chunk_size": 10 } + - match: { "custom-chunking-dense.mappings.properties.inference_field.chunking_settings.overlap": 1 } + + +--- +"We do not set custom chunking settings for null or empty specified chunking settings": + + - do: + indices.create: + index: null-chunking + body: + mappings: + properties: + inference_field: + type: semantic_text + inference_id: dense-inference-id + chunking_settings: null + + - do: + indices.get_mapping: + index: null-chunking + + - not_exists: null-chunking.mappings.properties.inference_field.chunking_settings + + + - do: + indices.create: + index: empty-chunking + body: + mappings: + properties: + inference_field: + type: semantic_text + inference_id: sparse-inference-id + chunking_settings: { } + + - do: + indices.get_mapping: + index: empty-chunking + + - not_exists: empty-chunking.mappings.properties.inference_field.chunking_settings + +--- +"We return different chunks based on configured chunking overrides or model defaults for sparse embeddings": + + - do: + search: + index: default-chunking-sparse + body: + query: + semantic: + field: "inference_field" + query: "What is Elasticsearch?" + highlight: + fields: + inference_field: + type: "semantic" + number_of_fragments: 3 + + - match: { hits.total.value: 1 } + - match: { hits.hits.0._id: "doc_1" } + - length: { hits.hits.0.highlight.inference_field: 1 } + - match: { hits.hits.0.highlight.inference_field.0: "Elasticsearch is an open source, distributed, RESTful, search engine which is built on top of Lucene internally and enjoys all the features it provides." } + + - do: + search: + index: custom-chunking-sparse + body: + query: + semantic: + field: "inference_field" + query: "What is Elasticsearch?" 
+ highlight: + fields: + inference_field: + type: "semantic" + number_of_fragments: 3 + + - match: { hits.total.value: 1 } + - match: { hits.hits.0._id: "doc_2" } + - length: { hits.hits.0.highlight.inference_field: 3 } + - match: { hits.hits.0.highlight.inference_field.0: "Elasticsearch is an open source, distributed, RESTful, search engine which" } + - match: { hits.hits.0.highlight.inference_field.1: " which is built on top of Lucene internally and enjoys" } + - match: { hits.hits.0.highlight.inference_field.2: " enjoys all the features it provides." } + +--- +"We return different chunks based on configured chunking overrides or model defaults for dense embeddings": + + - do: + search: + index: default-chunking-dense + body: + query: + semantic: + field: "inference_field" + query: "What is Elasticsearch?" + highlight: + fields: + inference_field: + type: "semantic" + number_of_fragments: 2 + + - match: { hits.total.value: 1 } + - match: { hits.hits.0._id: "doc_3" } + - length: { hits.hits.0.highlight.inference_field: 1 } + - match: { hits.hits.0.highlight.inference_field.0: "Elasticsearch is an open source, distributed, RESTful, search engine which is built on top of Lucene internally and enjoys all the features it provides." } + + - do: + search: + index: custom-chunking-dense + body: + query: + semantic: + field: "inference_field" + query: "What is Elasticsearch?" + highlight: + fields: + inference_field: + type: "semantic" + number_of_fragments: 3 + + - match: { hits.total.value: 1 } + - match: { hits.hits.0._id: "doc_4" } + - length: { hits.hits.0.highlight.inference_field: 3 } + - match: { hits.hits.0.highlight.inference_field.0: "Elasticsearch is an open source, distributed, RESTful, search engine which" } + - match: { hits.hits.0.highlight.inference_field.1: " which is built on top of Lucene internally and enjoys" } + - match: { hits.hits.0.highlight.inference_field.2: " enjoys all the features it provides." } + +--- +"We respect multiple semantic_text fields with different chunking configurations": + + - do: + indices.create: + index: mixed-chunking + body: + mappings: + properties: + keyword_field: + type: keyword + default_chunked_inference_field: + type: semantic_text + inference_id: sparse-inference-id + customized_chunked_inference_field: + type: semantic_text + inference_id: sparse-inference-id + chunking_settings: + strategy: word + max_chunk_size: 10 + overlap: 1 + + - do: + index: + index: mixed-chunking + id: doc_1 + body: + default_chunked_inference_field: "Elasticsearch is an open source, distributed, RESTful, search engine which is built on top of Lucene internally and enjoys all the features it provides." + customized_chunked_inference_field: "Elasticsearch is an open source, distributed, RESTful, search engine which is built on top of Lucene internally and enjoys all the features it provides." + refresh: true + + - do: + search: + index: mixed-chunking + body: + query: + bool: + should: + - match: + default_chunked_inference_field: "What is Elasticsearch?" + - match: + customized_chunked_inference_field: "What is Elasticsearch?" 
+ highlight: + fields: + default_chunked_inference_field: + type: "semantic" + number_of_fragments: 3 + customized_chunked_inference_field: + type: "semantic" + number_of_fragments: 3 + + - match: { hits.total.value: 1 } + - match: { hits.hits.0._id: "doc_1" } + - length: { hits.hits.0.highlight.default_chunked_inference_field: 1 } + - match: { hits.hits.0.highlight.default_chunked_inference_field.0: "Elasticsearch is an open source, distributed, RESTful, search engine which is built on top of Lucene internally and enjoys all the features it provides." } + - length: { hits.hits.0.highlight.customized_chunked_inference_field: 3 } + - match: { hits.hits.0.highlight.customized_chunked_inference_field.0: "Elasticsearch is an open source, distributed, RESTful, search engine which" } + - match: { hits.hits.0.highlight.customized_chunked_inference_field.1: " which is built on top of Lucene internally and enjoys" } + - match: { hits.hits.0.highlight.customized_chunked_inference_field.2: " enjoys all the features it provides." } + +--- +"Bulk requests are handled appropriately": + + - do: + indices.create: + index: index1 + body: + mappings: + properties: + keyword_field: + type: keyword + inference_field: + type: semantic_text + inference_id: sparse-inference-id + chunking_settings: + strategy: word + max_chunk_size: 10 + overlap: 1 + + - do: + indices.create: + index: index2 + body: + mappings: + properties: + keyword_field: + type: keyword + inference_field: + type: semantic_text + inference_id: sparse-inference-id + + - do: + bulk: + refresh: true + body: | + { "index": { "_index": "index1", "_id": "doc_1" }} + { "inference_field": "Elasticsearch is an open source, distributed, RESTful, search engine which is built on top of Lucene internally and enjoys all the features it provides." } + { "index": { "_index": "index2", "_id": "doc_2" }} + { "inference_field": "Elasticsearch is an open source, distributed, RESTful, search engine which is built on top of Lucene internally and enjoys all the features it provides." } + { "index": { "_index": "index1", "_id": "doc_3" }} + { "inference_field": "Elasticsearch is a free, open-source search engine and analytics tool that stores and indexes data." } + + - do: + search: + index: index1,index2 + body: + query: + semantic: + field: "inference_field" + query: "What is Elasticsearch?" + highlight: + fields: + inference_field: + type: "semantic" + number_of_fragments: 3 + + - match: { hits.total.value: 3 } + + - match: { hits.hits.0._id: "doc_3" } + - length: { hits.hits.0.highlight.inference_field: 2 } + - match: { hits.hits.0.highlight.inference_field.0: "Elasticsearch is a free, open-source search engine and analytics" } + - match: { hits.hits.0.highlight.inference_field.1: " analytics tool that stores and indexes data." } + + - match: { hits.hits.1._id: "doc_1" } + - length: { hits.hits.1.highlight.inference_field: 3 } + - match: { hits.hits.1.highlight.inference_field.0: "Elasticsearch is an open source, distributed, RESTful, search engine which" } + - match: { hits.hits.1.highlight.inference_field.1: " which is built on top of Lucene internally and enjoys" } + - match: { hits.hits.1.highlight.inference_field.2: " enjoys all the features it provides." 
} + + - match: { hits.hits.2._id: "doc_2" } + - length: { hits.hits.2.highlight.inference_field: 1 } + - match: { hits.hits.2.highlight.inference_field.0: "Elasticsearch is an open source, distributed, RESTful, search engine which is built on top of Lucene internally and enjoys all the features it provides." } + + +--- +"Invalid chunking settings will result in an error": + + - do: + catch: /chunking settings can not have the following settings/ + indices.create: + index: invalid-chunking-extra-stuff + body: + mappings: + properties: + keyword_field: + type: keyword + inference_field: + type: semantic_text + inference_id: sparse-inference-id + chunking_settings: + strategy: word + max_chunk_size: 10 + overlap: 1 + extra: stuff + + - do: + catch: /\[chunking_settings\] does not contain the required setting \[max_chunk_size\]/ + indices.create: + index: invalid-chunking-missing-required-settings + body: + mappings: + properties: + keyword_field: + type: keyword + inference_field: + type: semantic_text + inference_id: sparse-inference-id + chunking_settings: + strategy: word + + - do: + catch: /Invalid chunkingStrategy/ + indices.create: + index: invalid-chunking-invalid-strategy + body: + mappings: + properties: + inference_field: + type: semantic_text + inference_id: sparse-inference-id + chunking_settings: + strategy: invalid + +--- +"We can update chunking settings": + + - do: + indices.create: + index: chunking-update + body: + mappings: + properties: + inference_field: + type: semantic_text + inference_id: sparse-inference-id + + - do: + indices.get_mapping: + index: chunking-update + + - not_exists: chunking-update.mappings.properties.inference_field.chunking_settings + + - do: + indices.put_mapping: + index: chunking-update + body: + properties: + inference_field: + type: semantic_text + inference_id: sparse-inference-id + chunking_settings: + strategy: word + max_chunk_size: 10 + overlap: 1 + + - do: + indices.get_mapping: + index: chunking-update + + - match: { "chunking-update.mappings.properties.inference_field.chunking_settings.strategy": "word" } + - match: { "chunking-update.mappings.properties.inference_field.chunking_settings.max_chunk_size": 10 } + - match: { "chunking-update.mappings.properties.inference_field.chunking_settings.overlap": 1 } + + - do: + indices.put_mapping: + index: chunking-update + body: + properties: + inference_field: + type: semantic_text + inference_id: sparse-inference-id + + - do: + indices.get_mapping: + index: chunking-update + + - not_exists: chunking-update.mappings.properties.inference_field.chunking_settings + diff --git a/x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/25_semantic_text_field_mapping_chunking_bwc.yml b/x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/25_semantic_text_field_mapping_chunking_bwc.yml new file mode 100644 index 0000000000000..f189d5535bb77 --- /dev/null +++ b/x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/25_semantic_text_field_mapping_chunking_bwc.yml @@ -0,0 +1,550 @@ +setup: + - requires: + cluster_features: "semantic_text.support_chunking_config" + reason: semantic_text chunking configuration added in 8.19 + + - do: + inference.put: + task_type: text_embedding + inference_id: dense-inference-id + body: > + { + "service": "text_embedding_test_service", + "service_settings": { + "model": "my_model", + "dimensions": 10, + "similarity": "cosine", + "api_key": "abc64" + }, + "task_settings": { + } + } + + - do: + 
inference.put: + task_type: sparse_embedding + inference_id: sparse-inference-id + body: > + { + "service": "test_service", + "service_settings": { + "model": "my_model", + "api_key": "abc64" + }, + "task_settings": { + } + } + + - do: + indices.create: + index: default-chunking-sparse + body: + settings: + index.mapping.semantic_text.use_legacy_format: true + mappings: + properties: + keyword_field: + type: keyword + inference_field: + type: semantic_text + inference_id: sparse-inference-id + + - do: + indices.create: + index: default-chunking-dense + body: + settings: + index.mapping.semantic_text.use_legacy_format: true + mappings: + properties: + keyword_field: + type: keyword + inference_field: + type: semantic_text + inference_id: dense-inference-id + + - do: + indices.create: + index: custom-chunking-sparse + body: + settings: + index.mapping.semantic_text.use_legacy_format: true + mappings: + properties: + keyword_field: + type: keyword + inference_field: + type: semantic_text + inference_id: sparse-inference-id + chunking_settings: + strategy: word + max_chunk_size: 10 + overlap: 1 + + - do: + indices.create: + index: custom-chunking-dense + body: + settings: + index.mapping.semantic_text.use_legacy_format: true + mappings: + properties: + keyword_field: + type: keyword + inference_field: + type: semantic_text + inference_id: dense-inference-id + chunking_settings: + strategy: word + max_chunk_size: 10 + overlap: 1 + + - do: + index: + index: default-chunking-sparse + id: doc_1 + body: + keyword_field: "default sentence chunking" + inference_field: "Elasticsearch is an open source, distributed, RESTful, search engine which is built on top of Lucene internally and enjoys all the features it provides." + refresh: true + + - do: + index: + index: custom-chunking-sparse + id: doc_2 + body: + keyword_field: "custom word chunking" + inference_field: "Elasticsearch is an open source, distributed, RESTful, search engine which is built on top of Lucene internally and enjoys all the features it provides." + refresh: true + + - do: + index: + index: default-chunking-dense + id: doc_3 + body: + keyword_field: "default sentence chunking" + inference_field: "Elasticsearch is an open source, distributed, RESTful, search engine which is built on top of Lucene internally and enjoys all the features it provides." + refresh: true + + - do: + index: + index: custom-chunking-dense + id: doc_4 + body: + keyword_field: "custom word chunking" + inference_field: "Elasticsearch is an open source, distributed, RESTful, search engine which is built on top of Lucene internally and enjoys all the features it provides." 
+ refresh: true + +--- +"We return chunking configurations with mappings": + + - do: + indices.get_mapping: + index: default-chunking-sparse + + - not_exists: default-chunking-sparse.mappings.properties.inference_field.chunking_settings + + - do: + indices.get_mapping: + index: custom-chunking-sparse + + - match: { "custom-chunking-sparse.mappings.properties.inference_field.chunking_settings.strategy": "word" } + - match: { "custom-chunking-sparse.mappings.properties.inference_field.chunking_settings.max_chunk_size": 10 } + - match: { "custom-chunking-sparse.mappings.properties.inference_field.chunking_settings.overlap": 1 } + + - do: + indices.get_mapping: + index: default-chunking-dense + + - not_exists: default-chunking-dense.mappings.properties.inference_field.chunking_settings + + - do: + indices.get_mapping: + index: custom-chunking-dense + + - match: { "custom-chunking-dense.mappings.properties.inference_field.chunking_settings.strategy": "word" } + - match: { "custom-chunking-dense.mappings.properties.inference_field.chunking_settings.max_chunk_size": 10 } + - match: { "custom-chunking-dense.mappings.properties.inference_field.chunking_settings.overlap": 1 } + + +--- +"We do not set custom chunking settings for null or empty specified chunking settings": + + - do: + indices.create: + index: null-chunking + body: + settings: + index.mapping.semantic_text.use_legacy_format: true + mappings: + properties: + inference_field: + type: semantic_text + inference_id: dense-inference-id + chunking_settings: null + + - do: + indices.get_mapping: + index: null-chunking + + - not_exists: null-chunking.mappings.properties.inference_field.chunking_settings + + + - do: + indices.create: + index: empty-chunking + body: + settings: + index.mapping.semantic_text.use_legacy_format: true + mappings: + properties: + inference_field: + type: semantic_text + inference_id: sparse-inference-id + chunking_settings: { } + + - do: + indices.get_mapping: + index: empty-chunking + + - not_exists: empty-chunking.mappings.properties.inference_field.chunking_settings + +--- +"We return different chunks based on configured chunking overrides or model defaults for sparse embeddings": + + - do: + search: + index: default-chunking-sparse + body: + query: + semantic: + field: "inference_field" + query: "What is Elasticsearch?" + highlight: + fields: + inference_field: + type: "semantic" + number_of_fragments: 3 + + - match: { hits.total.value: 1 } + - match: { hits.hits.0._id: "doc_1" } + - length: { hits.hits.0.highlight.inference_field: 1 } + - match: { hits.hits.0.highlight.inference_field.0: "Elasticsearch is an open source, distributed, RESTful, search engine which is built on top of Lucene internally and enjoys all the features it provides." } + + - do: + search: + index: custom-chunking-sparse + body: + query: + semantic: + field: "inference_field" + query: "What is Elasticsearch?" + highlight: + fields: + inference_field: + type: "semantic" + number_of_fragments: 3 + + - match: { hits.total.value: 1 } + - match: { hits.hits.0._id: "doc_2" } + - length: { hits.hits.0.highlight.inference_field: 3 } + - match: { hits.hits.0.highlight.inference_field.0: "Elasticsearch is an open source, distributed, RESTful, search engine which" } + - match: { hits.hits.0.highlight.inference_field.1: " which is built on top of Lucene internally and enjoys" } + - match: { hits.hits.0.highlight.inference_field.2: " enjoys all the features it provides." 
} + +--- +"We return different chunks based on configured chunking overrides or model defaults for dense embeddings": + + - do: + search: + index: default-chunking-dense + body: + query: + semantic: + field: "inference_field" + query: "What is Elasticsearch?" + highlight: + fields: + inference_field: + type: "semantic" + number_of_fragments: 2 + + - match: { hits.total.value: 1 } + - match: { hits.hits.0._id: "doc_3" } + - length: { hits.hits.0.highlight.inference_field: 1 } + - match: { hits.hits.0.highlight.inference_field.0: "Elasticsearch is an open source, distributed, RESTful, search engine which is built on top of Lucene internally and enjoys all the features it provides." } + + - do: + search: + index: custom-chunking-dense + body: + query: + semantic: + field: "inference_field" + query: "What is Elasticsearch?" + highlight: + fields: + inference_field: + type: "semantic" + number_of_fragments: 3 + + - match: { hits.total.value: 1 } + - match: { hits.hits.0._id: "doc_4" } + - length: { hits.hits.0.highlight.inference_field: 3 } + - match: { hits.hits.0.highlight.inference_field.0: "Elasticsearch is an open source, distributed, RESTful, search engine which" } + - match: { hits.hits.0.highlight.inference_field.1: " which is built on top of Lucene internally and enjoys" } + - match: { hits.hits.0.highlight.inference_field.2: " enjoys all the features it provides." } + +--- +"We respect multiple semantic_text fields with different chunking configurations": + + - do: + indices.create: + index: mixed-chunking + body: + settings: + index.mapping.semantic_text.use_legacy_format: true + mappings: + properties: + keyword_field: + type: keyword + default_chunked_inference_field: + type: semantic_text + inference_id: sparse-inference-id + customized_chunked_inference_field: + type: semantic_text + inference_id: sparse-inference-id + chunking_settings: + strategy: word + max_chunk_size: 10 + overlap: 1 + + - do: + index: + index: mixed-chunking + id: doc_1 + body: + default_chunked_inference_field: "Elasticsearch is an open source, distributed, RESTful, search engine which is built on top of Lucene internally and enjoys all the features it provides." + customized_chunked_inference_field: "Elasticsearch is an open source, distributed, RESTful, search engine which is built on top of Lucene internally and enjoys all the features it provides." + refresh: true + + - do: + search: + index: mixed-chunking + body: + query: + bool: + should: + - match: + default_chunked_inference_field: "What is Elasticsearch?" + - match: + customized_chunked_inference_field: "What is Elasticsearch?" + highlight: + fields: + default_chunked_inference_field: + type: "semantic" + number_of_fragments: 3 + customized_chunked_inference_field: + type: "semantic" + number_of_fragments: 3 + + - match: { hits.total.value: 1 } + - match: { hits.hits.0._id: "doc_1" } + - length: { hits.hits.0.highlight.default_chunked_inference_field: 1 } + - match: { hits.hits.0.highlight.default_chunked_inference_field.0: "Elasticsearch is an open source, distributed, RESTful, search engine which is built on top of Lucene internally and enjoys all the features it provides." 
} + - length: { hits.hits.0.highlight.customized_chunked_inference_field: 3 } + - match: { hits.hits.0.highlight.customized_chunked_inference_field.0: "Elasticsearch is an open source, distributed, RESTful, search engine which" } + - match: { hits.hits.0.highlight.customized_chunked_inference_field.1: " which is built on top of Lucene internally and enjoys" } + - match: { hits.hits.0.highlight.customized_chunked_inference_field.2: " enjoys all the features it provides." } + + +--- +"Bulk requests are handled appropriately": + + - do: + indices.create: + index: index1 + body: + settings: + index.mapping.semantic_text.use_legacy_format: true + mappings: + properties: + keyword_field: + type: keyword + inference_field: + type: semantic_text + inference_id: sparse-inference-id + chunking_settings: + strategy: word + max_chunk_size: 10 + overlap: 1 + + - do: + indices.create: + index: index2 + body: + settings: + index.mapping.semantic_text.use_legacy_format: true + mappings: + properties: + keyword_field: + type: keyword + inference_field: + type: semantic_text + inference_id: sparse-inference-id + + - do: + bulk: + refresh: true + body: | + { "index": { "_index": "index1", "_id": "doc_1" }} + { "inference_field": "Elasticsearch is an open source, distributed, RESTful, search engine which is built on top of Lucene internally and enjoys all the features it provides." } + { "index": { "_index": "index2", "_id": "doc_2" }} + { "inference_field": "Elasticsearch is an open source, distributed, RESTful, search engine which is built on top of Lucene internally and enjoys all the features it provides." } + { "index": { "_index": "index1", "_id": "doc_3" }} + { "inference_field": "Elasticsearch is a free, open-source search engine and analytics tool that stores and indexes data." } + + - do: + search: + index: index1,index2 + body: + query: + semantic: + field: "inference_field" + query: "What is Elasticsearch?" + highlight: + fields: + inference_field: + type: "semantic" + number_of_fragments: 3 + + - match: { hits.total.value: 3 } + + - match: { hits.hits.0._id: "doc_3" } + - length: { hits.hits.0.highlight.inference_field: 2 } + - match: { hits.hits.0.highlight.inference_field.0: "Elasticsearch is a free, open-source search engine and analytics" } + - match: { hits.hits.0.highlight.inference_field.1: " analytics tool that stores and indexes data." } + + - match: { hits.hits.1._id: "doc_1" } + - length: { hits.hits.1.highlight.inference_field: 3 } + - match: { hits.hits.1.highlight.inference_field.0: "Elasticsearch is an open source, distributed, RESTful, search engine which" } + - match: { hits.hits.1.highlight.inference_field.1: " which is built on top of Lucene internally and enjoys" } + - match: { hits.hits.1.highlight.inference_field.2: " enjoys all the features it provides." } + + - match: { hits.hits.2._id: "doc_2" } + - length: { hits.hits.2.highlight.inference_field: 1 } + - match: { hits.hits.2.highlight.inference_field.0: "Elasticsearch is an open source, distributed, RESTful, search engine which is built on top of Lucene internally and enjoys all the features it provides." 
} + + +--- +"Invalid chunking settings will result in an error": + + - do: + catch: /chunking settings can not have the following settings/ + indices.create: + index: invalid-chunking-extra-stuff + body: + settings: + index.mapping.semantic_text.use_legacy_format: true + mappings: + properties: + keyword_field: + type: keyword + inference_field: + type: semantic_text + inference_id: sparse-inference-id + chunking_settings: + strategy: word + max_chunk_size: 10 + overlap: 1 + extra: stuff + + - do: + catch: /\[chunking_settings\] does not contain the required setting \[max_chunk_size\]/ + indices.create: + index: invalid-chunking-missing-required-settings + body: + settings: + index.mapping.semantic_text.use_legacy_format: true + mappings: + properties: + keyword_field: + type: keyword + inference_field: + type: semantic_text + inference_id: sparse-inference-id + chunking_settings: + strategy: word + + - do: + catch: /Invalid chunkingStrategy/ + indices.create: + index: invalid-chunking-invalid-strategy + body: + settings: + index.mapping.semantic_text.use_legacy_format: true + mappings: + properties: + inference_field: + type: semantic_text + inference_id: sparse-inference-id + chunking_settings: + strategy: invalid + +--- +"We can update chunking settings": + + - do: + indices.create: + index: chunking-update + body: + settings: + index.mapping.semantic_text.use_legacy_format: true + mappings: + properties: + inference_field: + type: semantic_text + inference_id: sparse-inference-id + + - do: + indices.get_mapping: + index: chunking-update + + - not_exists: chunking-update.mappings.properties.inference_field.chunking_settings + + - do: + indices.put_mapping: + index: chunking-update + body: + properties: + inference_field: + type: semantic_text + inference_id: sparse-inference-id + chunking_settings: + strategy: word + max_chunk_size: 10 + overlap: 1 + + - do: + indices.get_mapping: + index: chunking-update + + - match: { "chunking-update.mappings.properties.inference_field.chunking_settings.strategy": "word" } + - match: { "chunking-update.mappings.properties.inference_field.chunking_settings.max_chunk_size": 10 } + - match: { "chunking-update.mappings.properties.inference_field.chunking_settings.overlap": 1 } + + - do: + indices.put_mapping: + index: chunking-update + body: + properties: + inference_field: + type: semantic_text + inference_id: sparse-inference-id + + - do: + indices.get_mapping: + index: chunking-update + + - not_exists: chunking-update.mappings.properties.inference_field.chunking_settings + diff --git a/x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/60_semantic_text_inference_update.yml b/x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/60_semantic_text_inference_update.yml index 27c405f6c23bf..35e472e72b06d 100644 --- a/x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/60_semantic_text_inference_update.yml +++ b/x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/60_semantic_text_inference_update.yml @@ -79,7 +79,7 @@ setup: body: fields: [ _inference_fields ] query: - match_all: {} + match_all: { } - match: { hits.total.value: 1 } - match: { hits.total.relation: eq } @@ -104,7 +104,7 @@ setup: body: fields: [ _inference_fields ] query: - match_all: {} + match_all: { } - match: { hits.total.value: 1 } - match: { hits.total.relation: eq } @@ -129,7 +129,7 @@ setup: body: fields: [ _inference_fields ] query: - match_all: {} + match_all: { } - match: 
{ hits.total.value: 1 } - match: { hits.total.relation: eq } @@ -152,7 +152,7 @@ setup: body: fields: [ _inference_fields ] query: - match_all: {} + match_all: { } # We can't directly check that the embeddings are different since there isn't a "does not match" assertion in the # YAML test framework. Check that the start and end offsets change as expected as a proxy. @@ -179,7 +179,7 @@ setup: body: fields: [ _inference_fields ] query: - match_all: {} + match_all: { } - match: { hits.total.value: 1 } - match: { hits.total.relation: eq } @@ -202,7 +202,7 @@ setup: body: fields: [ _inference_fields ] query: - match_all: {} + match_all: { } - match: { hits.total.value: 1 } - match: { hits.total.relation: eq } @@ -254,7 +254,7 @@ setup: body: fields: [ _inference_fields ] query: - match_all: {} + match_all: { } - match: { hits.total.value: 1 } - match: { hits.total.relation: eq } @@ -283,7 +283,7 @@ setup: body: fields: [ _inference_fields ] query: - match_all: {} + match_all: { } - match: { hits.total.value: 1 } - match: { hits.total.relation: eq } @@ -320,7 +320,7 @@ setup: body: fields: [ _inference_fields ] query: - match_all: {} + match_all: { } - match: { hits.total.value: 1 } - match: { hits.total.relation: eq } @@ -367,7 +367,7 @@ setup: index: test-index id: doc_1 body: - doc: { "sparse_field": [{"key": "value"}], "dense_field": [{"key": "value"}] } + doc: { "sparse_field": [ { "key": "value" } ], "dense_field": [ { "key": "value" } ] } - match: { error.type: "status_exception" } - match: { error.reason: "/Invalid\\ format\\ for\\ field\\ \\[(dense|sparse)_field\\].+/" } @@ -415,7 +415,7 @@ setup: body: fields: [ _inference_fields ] query: - match_all: {} + match_all: { } - match: { hits.total.value: 1 } - match: { hits.total.relation: eq } @@ -448,7 +448,7 @@ setup: body: fields: [ _inference_fields ] query: - match_all: {} + match_all: { } - match: { hits.total.value: 1 } - match: { hits.total.relation: eq } @@ -509,7 +509,7 @@ setup: body: fields: [ _inference_fields ] query: - match_all: {} + match_all: { } - match: { hits.total.value: 1 } - match: { hits.total.relation: eq } @@ -540,7 +540,7 @@ setup: body: fields: [ _inference_fields ] query: - match_all: {} + match_all: { } - match: { hits.total.value: 1 } - match: { hits.total.relation: eq } From d11e4427a398db85611befac0bc0aadff1df9935 Mon Sep 17 00:00:00 2001 From: elasticsearchmachine Date: Fri, 4 Apr 2025 02:35:04 +0000 Subject: [PATCH 4/7] [CI] Auto commit changes from spotless --- .../core/ml/job/process/autodetect/state/ModelSizeStats.java | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/job/process/autodetect/state/ModelSizeStats.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/job/process/autodetect/state/ModelSizeStats.java index 0757c9424b309..4891844c74a4c 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/job/process/autodetect/state/ModelSizeStats.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/job/process/autodetect/state/ModelSizeStats.java @@ -629,8 +629,7 @@ public Builder setPeakModelBytes(long peakModelBytes) { return this; } - public Builder setActualMemoryUsageBytes(long actualMemoryUsageBytes) - { + public Builder setActualMemoryUsageBytes(long actualMemoryUsageBytes) { this.actualMemoryUsageBytes = actualMemoryUsageBytes; return this; } From d67d5d31b9af6fc393c5c42dddb94bd6855a73dc Mon Sep 17 00:00:00 2001 From: Ed Savage Date: Wed, 9 Apr 2025 
16:03:47 +1200 Subject: [PATCH 5/7] Account for C++ code changes --- .../org/elasticsearch/TransportVersions.java | 2 +- .../autodetect/state/ModelSizeStats.java | 73 +++++++++++++------ .../job/persistence/JobResultsProvider.java | 13 +++- .../autodetect/AutodetectProcessManager.java | 4 +- .../AutodetectProcessManagerTests.java | 9 ++- 5 files changed, 70 insertions(+), 31 deletions(-) diff --git a/server/src/main/java/org/elasticsearch/TransportVersions.java b/server/src/main/java/org/elasticsearch/TransportVersions.java index dfbd5a5fd557f..5a5f4cc15633e 100644 --- a/server/src/main/java/org/elasticsearch/TransportVersions.java +++ b/server/src/main/java/org/elasticsearch/TransportVersions.java @@ -214,7 +214,7 @@ static TransportVersion def(int id) { public static final TransportVersion ESQL_REMOVE_AGGREGATE_TYPE = def(9_045_0_00); public static final TransportVersion ADD_PROJECT_ID_TO_DSL_ERROR_INFO = def(9_046_0_00); public static final TransportVersion SEMANTIC_TEXT_CHUNKING_CONFIG = def(9_047_00_0); - public static final TransportVersion ML_AD_ACTUAL_MEMORY_USAGE = def(9_048_0_00); + public static final TransportVersion ML_AD_SYSTEM_MEMORY_USAGE = def(9_048_0_00); /* * STOP! READ THIS FIRST! No, really, diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/job/process/autodetect/state/ModelSizeStats.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/job/process/autodetect/state/ModelSizeStats.java index 4891844c74a4c..6d09f135d8b10 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/job/process/autodetect/state/ModelSizeStats.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/job/process/autodetect/state/ModelSizeStats.java @@ -41,7 +41,8 @@ public class ModelSizeStats implements ToXContentObject, Writeable { */ public static final ParseField MODEL_BYTES_FIELD = new ParseField("model_bytes"); public static final ParseField PEAK_MODEL_BYTES_FIELD = new ParseField("peak_model_bytes"); - public static final ParseField ACTUAL_MEMORY_USAGE_BYTES = new ParseField("actual_memory_usage_bytes"); + public static final ParseField SYSTEM_MEMORY_BYTES = new ParseField("system_memory_bytes"); + public static final ParseField MAX_SYSTEM_MEMORY_BYTES = new ParseField("max_system_memory_bytes"); public static final ParseField MODEL_BYTES_EXCEEDED_FIELD = new ParseField("model_bytes_exceeded"); public static final ParseField MODEL_BYTES_MEMORY_LIMIT_FIELD = new ParseField("model_bytes_memory_limit"); public static final ParseField TOTAL_BY_FIELD_COUNT_FIELD = new ParseField("total_by_field_count"); @@ -75,7 +76,8 @@ private static ConstructingObjectParser createParser(boolean igno parser.declareString((modelSizeStat, s) -> {}, Result.RESULT_TYPE); parser.declareLong(Builder::setModelBytes, MODEL_BYTES_FIELD); parser.declareLong(Builder::setPeakModelBytes, PEAK_MODEL_BYTES_FIELD); - parser.declareLong(Builder::setActualMemoryUsageBytes, ACTUAL_MEMORY_USAGE_BYTES); + parser.declareLong(Builder::setSystemMemoryBytes, SYSTEM_MEMORY_BYTES); + parser.declareLong(Builder::setMaxSystemMemoryBytes, MAX_SYSTEM_MEMORY_BYTES); parser.declareLong(Builder::setModelBytesExceeded, MODEL_BYTES_EXCEEDED_FIELD); parser.declareLong(Builder::setModelBytesMemoryLimit, MODEL_BYTES_MEMORY_LIMIT_FIELD); parser.declareLong(Builder::setBucketAllocationFailuresCount, BUCKET_ALLOCATION_FAILURES_COUNT_FIELD); @@ -154,7 +156,8 @@ public String toString() { * 1. The job's model_memory_limit * 2. The current model memory, i.e. 
what's reported in model_bytes of this object * 3. The peak model memory, i.e. what's reported in peak_model_bytes of this object - * 4. The actual memory usage, i.e. what's reported in actual_memory_usage_bytes of this object + * 4. The system memory, i.e. what's reported in system_memory_bytes of this object + * 5. The max system memory, i.e. what's reported in max_system_memory_bytes of this object * The field storing this enum can also be null, which means the * assignment code will decide on the fly - this was the old behaviour prior * to 7.11. @@ -163,7 +166,8 @@ public enum AssignmentMemoryBasis implements Writeable { MODEL_MEMORY_LIMIT, CURRENT_MODEL_BYTES, PEAK_MODEL_BYTES, - ACTUAL_MEMORY_USAGE_BYTES; + SYSTEM_MEMORY_BYTES, + MAX_SYSTEM_MEMORY_BYTES; public static AssignmentMemoryBasis fromString(String statusName) { return valueOf(statusName.trim().toUpperCase(Locale.ROOT)); @@ -187,7 +191,8 @@ public String toString() { private final String jobId; private final long modelBytes; private final Long peakModelBytes; - private final Long actualMemoryUsageBytes; + private final Long systemMemoryUsageBytes; + private final Long maxSystemMemoryUsageBytes; private final Long modelBytesExceeded; private final Long modelBytesMemoryLimit; private final long totalByFieldCount; @@ -211,7 +216,8 @@ private ModelSizeStats( String jobId, long modelBytes, Long peakModelBytes, - Long actualMemoryUsageBytes, + Long systemMemoryUsageBytes, + Long maxSystemMemoryUsageBytes, Long modelBytesExceeded, Long modelBytesMemoryLimit, long totalByFieldCount, @@ -234,7 +240,8 @@ private ModelSizeStats( this.jobId = jobId; this.modelBytes = modelBytes; this.peakModelBytes = peakModelBytes; - this.actualMemoryUsageBytes = actualMemoryUsageBytes; + this.systemMemoryUsageBytes = systemMemoryUsageBytes; + this.maxSystemMemoryUsageBytes = maxSystemMemoryUsageBytes; this.modelBytesExceeded = modelBytesExceeded; this.modelBytesMemoryLimit = modelBytesMemoryLimit; this.totalByFieldCount = totalByFieldCount; @@ -259,10 +266,12 @@ public ModelSizeStats(StreamInput in) throws IOException { jobId = in.readString(); modelBytes = in.readVLong(); peakModelBytes = in.readOptionalLong(); - if (in.getTransportVersion().onOrAfter(TransportVersions.ML_AD_ACTUAL_MEMORY_USAGE)) { - actualMemoryUsageBytes = in.readOptionalLong(); + if (in.getTransportVersion().onOrAfter(TransportVersions.ML_AD_SYSTEM_MEMORY_USAGE)) { + systemMemoryUsageBytes = in.readOptionalLong(); + maxSystemMemoryUsageBytes = in.readOptionalLong(); } else { - actualMemoryUsageBytes = null; + systemMemoryUsageBytes = null; + maxSystemMemoryUsageBytes = null; } modelBytesExceeded = in.readOptionalLong(); modelBytesMemoryLimit = in.readOptionalLong(); @@ -305,8 +314,9 @@ public void writeTo(StreamOutput out) throws IOException { out.writeString(jobId); out.writeVLong(modelBytes); out.writeOptionalLong(peakModelBytes); - if (out.getTransportVersion().onOrAfter(TransportVersions.ML_AD_ACTUAL_MEMORY_USAGE)) { - out.writeOptionalLong(actualMemoryUsageBytes); + if (out.getTransportVersion().onOrAfter(TransportVersions.ML_AD_SYSTEM_MEMORY_USAGE)) { + out.writeOptionalLong(systemMemoryUsageBytes); + out.writeOptionalLong(maxSystemMemoryUsageBytes); } out.writeOptionalLong(modelBytesExceeded); out.writeOptionalLong(modelBytesMemoryLimit); @@ -354,8 +364,11 @@ public XContentBuilder doXContentBody(XContentBuilder builder) throws IOExceptio if (peakModelBytes != null) { builder.field(PEAK_MODEL_BYTES_FIELD.getPreferredName(), peakModelBytes); } - if (actualMemoryUsageBytes
!= null) { - builder.field(ACTUAL_MEMORY_USAGE_BYTES.getPreferredName(), actualMemoryUsageBytes); + if (systemMemoryUsageBytes != null) { + builder.field(SYSTEM_MEMORY_BYTES.getPreferredName(), systemMemoryUsageBytes); + } + if (maxSystemMemoryUsageBytes != null) { + builder.field(MAX_SYSTEM_MEMORY_BYTES.getPreferredName(), maxSystemMemoryUsageBytes); } if (modelBytesExceeded != null) { builder.field(MODEL_BYTES_EXCEEDED_FIELD.getPreferredName(), modelBytesExceeded); @@ -409,8 +422,12 @@ public Long getPeakModelBytes() { return peakModelBytes; } - public Long getActualMemoryUsageBytes() { - return actualMemoryUsageBytes; + public Long getSystemMemoryBytes() { + return systemMemoryUsageBytes; + } + + public Long getMaxSystemMemoryBytes() { + return maxSystemMemoryUsageBytes; } public Long getModelBytesExceeded() { @@ -501,7 +518,8 @@ public int hashCode() { jobId, modelBytes, peakModelBytes, - actualMemoryUsageBytes, + systemMemoryUsageBytes, + maxSystemMemoryUsageBytes, modelBytesExceeded, modelBytesMemoryLimit, totalByFieldCount, @@ -540,7 +558,8 @@ public boolean equals(Object other) { return this.modelBytes == that.modelBytes && Objects.equals(this.peakModelBytes, that.peakModelBytes) - && this.actualMemoryUsageBytes == that.actualMemoryUsageBytes + && Objects.equals(this.systemMemoryUsageBytes, that.systemMemoryUsageBytes) + && Objects.equals(this.maxSystemMemoryUsageBytes, that.maxSystemMemoryUsageBytes) && Objects.equals(this.modelBytesExceeded, that.modelBytesExceeded) && Objects.equals(this.modelBytesMemoryLimit, that.modelBytesMemoryLimit) && this.totalByFieldCount == that.totalByFieldCount @@ -567,7 +586,8 @@ public static class Builder { private final String jobId; private long modelBytes; private Long peakModelBytes; - private Long actualMemoryUsageBytes; + private Long systemMemoryUsageBytes; + private Long maxSystemMemoryUsageBytes; private Long modelBytesExceeded; private Long modelBytesMemoryLimit; private long totalByFieldCount; @@ -598,7 +618,8 @@ public Builder(ModelSizeStats modelSizeStats) { this.jobId = modelSizeStats.jobId; this.modelBytes = modelSizeStats.modelBytes; this.peakModelBytes = modelSizeStats.peakModelBytes; - this.actualMemoryUsageBytes = modelSizeStats.actualMemoryUsageBytes; + this.systemMemoryUsageBytes = modelSizeStats.systemMemoryUsageBytes; + this.maxSystemMemoryUsageBytes = modelSizeStats.maxSystemMemoryUsageBytes; this.modelBytesExceeded = modelSizeStats.modelBytesExceeded; this.modelBytesMemoryLimit = modelSizeStats.modelBytesMemoryLimit; this.totalByFieldCount = modelSizeStats.totalByFieldCount; @@ -629,8 +650,13 @@ public Builder setPeakModelBytes(long peakModelBytes) { return this; } - public Builder setActualMemoryUsageBytes(long actualMemoryUsageBytes) { - this.actualMemoryUsageBytes = actualMemoryUsageBytes; + public Builder setSystemMemoryBytes(long systemMemoryUsageBytes) { + this.systemMemoryUsageBytes = systemMemoryUsageBytes; + return this; + } + + public Builder setMaxSystemMemoryBytes(long maxSystemMemoryUsageBytes) { + this.maxSystemMemoryUsageBytes = maxSystemMemoryUsageBytes; return this; } @@ -731,7 +757,8 @@ public ModelSizeStats build() { jobId, modelBytes, peakModelBytes, - actualMemoryUsageBytes, + systemMemoryUsageBytes, + maxSystemMemoryUsageBytes, modelBytesExceeded, modelBytesMemoryLimit, totalByFieldCount, diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/job/persistence/JobResultsProvider.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/job/persistence/JobResultsProvider.java index cc685e29e3d04..a2dffa395de34
100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/job/persistence/JobResultsProvider.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/job/persistence/JobResultsProvider.java @@ -1590,10 +1590,17 @@ void calculateEstablishedMemoryUsage( handler.accept((storedPeak != null) ? storedPeak : latestModelSizeStats.getModelBytes()); return; } - case ACTUAL_MEMORY_USAGE_BYTES -> { - Long storedActualMemoryUsageBytes = latestModelSizeStats.getActualMemoryUsageBytes(); + case SYSTEM_MEMORY_BYTES -> { + Long storedSystemMemoryBytes = latestModelSizeStats.getSystemMemoryBytes(); handler.accept( - (storedActualMemoryUsageBytes != null) ? storedActualMemoryUsageBytes : latestModelSizeStats.getModelBytes() + (storedSystemMemoryBytes != null) ? storedSystemMemoryBytes : latestModelSizeStats.getModelBytes() + ); + return; + } + case MAX_SYSTEM_MEMORY_BYTES -> { + Long storedMaxSystemMemoryBytes = latestModelSizeStats.getMaxSystemMemoryBytes(); + handler.accept( + (storedMaxSystemMemoryBytes != null) ? storedMaxSystemMemoryBytes : latestModelSizeStats.getModelBytes() ); return; } diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/job/process/autodetect/AutodetectProcessManager.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/job/process/autodetect/AutodetectProcessManager.java index ba10d0c8090f6..c4694dc02b7d8 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/job/process/autodetect/AutodetectProcessManager.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/job/process/autodetect/AutodetectProcessManager.java @@ -1076,7 +1076,9 @@ public ByteSizeValue getOpenProcessMemoryUsage() { case MODEL_MEMORY_LIMIT -> Optional.ofNullable(modelSizeStats.getModelBytesMemoryLimit()).orElse(0L); case CURRENT_MODEL_BYTES -> modelSizeStats.getModelBytes(); case PEAK_MODEL_BYTES -> Optional.ofNullable(modelSizeStats.getPeakModelBytes()).orElse(modelSizeStats.getModelBytes()); - case ACTUAL_MEMORY_USAGE_BYTES -> Optional.ofNullable(modelSizeStats.getActualMemoryUsageBytes()) + case SYSTEM_MEMORY_BYTES -> Optional.ofNullable(modelSizeStats.getSystemMemoryBytes()) + .orElse(modelSizeStats.getModelBytes()); + case MAX_SYSTEM_MEMORY_BYTES -> Optional.ofNullable(modelSizeStats.getMaxSystemMemoryBytes()) .orElse(modelSizeStats.getModelBytes()); }; memoryUsedBytes += Job.PROCESS_MEMORY_OVERHEAD.getBytes(); diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/job/process/autodetect/AutodetectProcessManagerTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/job/process/autodetect/AutodetectProcessManagerTests.java index 03aeb18ebd3b1..0a542ce16a8b9 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/job/process/autodetect/AutodetectProcessManagerTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/job/process/autodetect/AutodetectProcessManagerTests.java @@ -834,12 +834,14 @@ public void testGetOpenProcessMemoryUsage() { long modelMemoryLimitBytes = ByteSizeValue.ofMb(randomIntBetween(10, 1000)).getBytes(); long peakModelBytes = randomLongBetween(100000, modelMemoryLimitBytes - 1); long modelBytes = randomLongBetween(1, peakModelBytes - 1); - long actualMemoryUsageBytes = randomLongBetween(262144, peakModelBytes - 1); + long systemMemoryUsageBytes = randomLongBetween(262144, peakModelBytes - 1); + long maxSystemMemoryUsageBytes = randomLongBetween(266240, peakModelBytes - 1); AssignmentMemoryBasis assignmentMemoryBasis = 
randomFrom(AssignmentMemoryBasis.values()); modelSizeStats = new ModelSizeStats.Builder("foo").setModelBytesMemoryLimit(modelMemoryLimitBytes) .setPeakModelBytes(peakModelBytes) .setModelBytes(modelBytes) - .setActualMemoryUsageBytes(actualMemoryUsageBytes) + .setSystemMemoryBytes(systemMemoryUsageBytes) + .setMaxSystemMemoryBytes(maxSystemMemoryUsageBytes) .setAssignmentMemoryBasis(assignmentMemoryBasis) .build(); when(autodetectCommunicator.getModelSizeStats()).thenReturn(modelSizeStats); @@ -852,7 +854,8 @@ public void testGetOpenProcessMemoryUsage() { case MODEL_MEMORY_LIMIT -> modelMemoryLimitBytes; case CURRENT_MODEL_BYTES -> modelBytes; case PEAK_MODEL_BYTES -> peakModelBytes; - case ACTUAL_MEMORY_USAGE_BYTES -> actualMemoryUsageBytes; + case SYSTEM_MEMORY_BYTES -> systemMemoryUsageBytes; + case MAX_SYSTEM_MEMORY_BYTES -> maxSystemMemoryUsageBytes; }; assertThat(manager.getOpenProcessMemoryUsage(), equalTo(ByteSizeValue.ofBytes(expectedSizeBytes))); } From 60528f7e26b15de23558c83467ca2f9cc3c0862d Mon Sep 17 00:00:00 2001 From: Ed Savage Date: Wed, 9 Apr 2025 16:07:51 +1200 Subject: [PATCH 6/7] Revert unneeded MlConfigVersion change --- .../java/org/elasticsearch/xpack/core/ml/MlConfigVersion.java | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/MlConfigVersion.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/MlConfigVersion.java index 06e5ff5b8d08c..260409db0e653 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/MlConfigVersion.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/MlConfigVersion.java @@ -153,13 +153,12 @@ private static void checkUniqueness(int id, String uniqueId) { // V_11 is used in ELSER v2 package configs public static final MlConfigVersion V_11 = registerMlConfigVersion(11_00_0_0_99, "79CB2950-57C7-11EE-AE5D-0800200C9A66"); public static final MlConfigVersion V_12 = registerMlConfigVersion(12_00_0_0_99, "Trained model config prefix strings added"); - public static final MlConfigVersion V_13 = registerMlConfigVersion(13_00_0_0_99, "Anomaly Detection reports actual memory usage"); /** * Reference to the most recent Ml config version. * This should be the Ml config version with the highest id. */ - public static final MlConfigVersion CURRENT = V_13; + public static final MlConfigVersion CURRENT = V_12; /** * Reference to the first MlConfigVersion that is detached from the From b490a5d5172d410174faab1cd59256f56ce3adb4 Mon Sep 17 00:00:00 2001 From: elasticsearchmachine Date: Wed, 9 Apr 2025 04:17:01 +0000 Subject: [PATCH 7/7] [CI] Auto commit changes from spotless --- .../xpack/ml/job/persistence/JobResultsProvider.java | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/job/persistence/JobResultsProvider.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/job/persistence/JobResultsProvider.java index a2dffa395de34..185d50b1b3558 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/job/persistence/JobResultsProvider.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/job/persistence/JobResultsProvider.java @@ -1592,9 +1592,7 @@ void calculateEstablishedMemoryUsage( } case SYSTEM_MEMORY_BYTES -> { Long storedSystemMemoryBytes = latestModelSizeStats.getSystemMemoryBytes(); - handler.accept( - (storedSystemMemoryBytes != null) ?
storedSystemMemoryBytes : latestModelSizeStats.getModelBytes() - ); + handler.accept((storedSystemMemoryBytes != null) ? storedSystemMemoryBytes : latestModelSizeStats.getModelBytes()); return; } case MAX_SYSTEM_MEMORY_BYTES -> {
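For reference, a minimal sketch of how the renamed API composes after this series, using only the Builder setters and getters visible in the diffs above. The wrapper class, job id, and byte values are illustrative assumptions, and the comments restate the rename's intent rather than documented semantics.

import org.elasticsearch.xpack.core.ml.job.process.autodetect.state.ModelSizeStats;
import org.elasticsearch.xpack.core.ml.job.process.autodetect.state.ModelSizeStats.AssignmentMemoryBasis;

public class ModelSizeStatsSketch {
    public static void main(String[] args) {
        // Mirrors the updated AutodetectProcessManagerTests setup:
        // system memory is the OS-reported figure for the autodetect process,
        // max system memory is presumably its observed maximum.
        ModelSizeStats stats = new ModelSizeStats.Builder("example-job")
            .setModelBytes(50_000_000L)
            .setPeakModelBytes(60_000_000L)
            .setSystemMemoryBytes(55_000_000L)
            .setMaxSystemMemoryBytes(65_000_000L)
            .setAssignmentMemoryBasis(AssignmentMemoryBasis.MAX_SYSTEM_MEMORY_BYTES)
            .build();

        // The getters follow the parse-field names, not the internal
        // *UsageBytes field names.
        System.out.println(stats.getSystemMemoryBytes());    // 55000000
        System.out.println(stats.getMaxSystemMemoryBytes()); // 65000000
    }
}

Serialized through doXContentBody, the same two values land in the job's model size stats document as "system_memory_bytes" and "max_system_memory_bytes".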