Skip to content

Commit 68f3a64

Browse files
[Inference API] Add special case to inference API (elastic#116962) (elastic#117035)
* Add reranker special case to inference API * Update docs/changelog/116962.yaml * Update 116962.yaml * spotless * improvements from review * Fix typo
1 parent cfb451e commit 68f3a64

File tree

6 files changed

+166
-1
lines changed

6 files changed

+166
-1
lines changed

docs/changelog/116962.yaml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
pr: 116962
2+
summary: "Add special case for elastic reranker in inference API"
3+
area: Machine Learning
4+
type: enhancement
5+
issues: []

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceNamedWriteablesProvider.java

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,7 @@
6363
import org.elasticsearch.xpack.inference.services.elasticsearch.CustomElandInternalServiceSettings;
6464
import org.elasticsearch.xpack.inference.services.elasticsearch.CustomElandInternalTextEmbeddingServiceSettings;
6565
import org.elasticsearch.xpack.inference.services.elasticsearch.CustomElandRerankTaskSettings;
66+
import org.elasticsearch.xpack.inference.services.elasticsearch.ElasticRerankerServiceSettings;
6667
import org.elasticsearch.xpack.inference.services.elasticsearch.ElasticsearchInternalServiceSettings;
6768
import org.elasticsearch.xpack.inference.services.elasticsearch.ElserInternalServiceSettings;
6869
import org.elasticsearch.xpack.inference.services.elasticsearch.ElserMlNodeTaskSettings;
@@ -415,7 +416,13 @@ private static void addInternalNamedWriteables(List<NamedWriteableRegistry.Entry
415416
MultilingualE5SmallInternalServiceSettings::new
416417
)
417418
);
418-
419+
namedWriteables.add(
420+
new NamedWriteableRegistry.Entry(
421+
ServiceSettings.class,
422+
ElasticRerankerServiceSettings.NAME,
423+
ElasticRerankerServiceSettings::new
424+
)
425+
);
419426
}
420427

421428
private static void addChunkedInferenceResultsNamedWriteables(List<NamedWriteableRegistry.Entry> namedWriteables) {

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/BaseElasticsearchInternalService.java

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -156,6 +156,8 @@ protected void putModel(Model model, ActionListener<Boolean> listener) {
156156
putBuiltInModel(e5Model.getServiceSettings().modelId(), listener);
157157
} else if (model instanceof ElserInternalModel elserModel) {
158158
putBuiltInModel(elserModel.getServiceSettings().modelId(), listener);
159+
} else if (model instanceof ElasticRerankerModel elasticRerankerModel) {
160+
putBuiltInModel(elasticRerankerModel.getServiceSettings().modelId(), listener);
159161
} else if (model instanceof CustomElandModel) {
160162
logger.info("Custom eland model detected, model must have been already loaded into the cluster with eland.");
161163
listener.onResponse(Boolean.TRUE);
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,60 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License
4+
* 2.0; you may not use this file except in compliance with the Elastic License
5+
* 2.0.
6+
*/
7+
8+
package org.elasticsearch.xpack.inference.services.elasticsearch;
9+
10+
import org.elasticsearch.ResourceNotFoundException;
11+
import org.elasticsearch.action.ActionListener;
12+
import org.elasticsearch.inference.ChunkingSettings;
13+
import org.elasticsearch.inference.Model;
14+
import org.elasticsearch.inference.TaskType;
15+
import org.elasticsearch.xpack.core.ml.action.CreateTrainedModelAssignmentAction;
16+
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
17+
18+
public class ElasticRerankerModel extends ElasticsearchInternalModel {
19+
20+
public ElasticRerankerModel(
21+
String inferenceEntityId,
22+
TaskType taskType,
23+
String service,
24+
ElasticRerankerServiceSettings serviceSettings,
25+
ChunkingSettings chunkingSettings
26+
) {
27+
super(inferenceEntityId, taskType, service, serviceSettings, chunkingSettings);
28+
}
29+
30+
@Override
31+
public ElasticRerankerServiceSettings getServiceSettings() {
32+
return (ElasticRerankerServiceSettings) super.getServiceSettings();
33+
}
34+
35+
@Override
36+
public ActionListener<CreateTrainedModelAssignmentAction.Response> getCreateTrainedModelAssignmentActionListener(
37+
Model model,
38+
ActionListener<Boolean> listener
39+
) {
40+
41+
return new ActionListener<>() {
42+
@Override
43+
public void onResponse(CreateTrainedModelAssignmentAction.Response response) {
44+
listener.onResponse(Boolean.TRUE);
45+
}
46+
47+
@Override
48+
public void onFailure(Exception e) {
49+
if (ExceptionsHelper.unwrapCause(e) instanceof ResourceNotFoundException) {
50+
listener.onFailure(
51+
new ResourceNotFoundException("Could not start the Elastic Reranker Endpoint due to [{}]", e, e.getMessage())
52+
);
53+
return;
54+
}
55+
listener.onFailure(e);
56+
}
57+
};
58+
}
59+
60+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,62 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License
4+
* 2.0; you may not use this file except in compliance with the Elastic License
5+
* 2.0.
6+
*/
7+
8+
package org.elasticsearch.xpack.inference.services.elasticsearch;
9+
10+
import org.elasticsearch.common.ValidationException;
11+
import org.elasticsearch.common.io.stream.StreamInput;
12+
import org.elasticsearch.xpack.core.ml.inference.assignment.AdaptiveAllocationsSettings;
13+
14+
import java.io.IOException;
15+
import java.util.Map;
16+
17+
public class ElasticRerankerServiceSettings extends ElasticsearchInternalServiceSettings {
18+
19+
public static final String NAME = "elastic_reranker_service_settings";
20+
21+
public ElasticRerankerServiceSettings(ElasticsearchInternalServiceSettings other) {
22+
super(other);
23+
}
24+
25+
public ElasticRerankerServiceSettings(
26+
Integer numAllocations,
27+
int numThreads,
28+
String modelId,
29+
AdaptiveAllocationsSettings adaptiveAllocationsSettings
30+
) {
31+
super(numAllocations, numThreads, modelId, adaptiveAllocationsSettings);
32+
}
33+
34+
public ElasticRerankerServiceSettings(StreamInput in) throws IOException {
35+
super(in);
36+
}
37+
38+
/**
39+
* Parse the ElasticRerankerServiceSettings from map and validate the setting values.
40+
*
41+
* If required setting are missing or the values are invalid an
42+
* {@link ValidationException} is thrown.
43+
*
44+
* @param map Source map containing the config
45+
* @return The builder
46+
*/
47+
public static Builder fromRequestMap(Map<String, Object> map) {
48+
ValidationException validationException = new ValidationException();
49+
var baseSettings = ElasticsearchInternalServiceSettings.fromMap(map, validationException);
50+
51+
if (validationException.validationErrors().isEmpty() == false) {
52+
throw validationException;
53+
}
54+
55+
return baseSettings;
56+
}
57+
58+
@Override
59+
public String getWriteableName() {
60+
return ElasticRerankerServiceSettings.NAME;
61+
}
62+
}

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalService.java

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -97,6 +97,8 @@ public class ElasticsearchInternalService extends BaseElasticsearchInternalServi
9797
MULTILINGUAL_E5_SMALL_MODEL_ID_LINUX_X86
9898
);
9999

100+
public static final String RERANKER_ID = ".rerank-v1";
101+
100102
public static final int EMBEDDING_MAX_BATCH_SIZE = 10;
101103
public static final String DEFAULT_ELSER_ID = ".elser-2-elasticsearch";
102104
public static final String DEFAULT_E5_ID = ".multilingual-e5-small-elasticsearch";
@@ -223,6 +225,8 @@ public void parseRequestConfig(
223225
)
224226
)
225227
);
228+
} else if (RERANKER_ID.equals(modelId)) {
229+
rerankerCase(inferenceEntityId, taskType, config, serviceSettingsMap, chunkingSettings, modelListener);
226230
} else {
227231
customElandCase(inferenceEntityId, taskType, serviceSettingsMap, taskSettingsMap, chunkingSettings, modelListener);
228232
}
@@ -323,6 +327,31 @@ private static CustomElandInternalServiceSettings elandServiceSettings(
323327
};
324328
}
325329

330+
private void rerankerCase(
331+
String inferenceEntityId,
332+
TaskType taskType,
333+
Map<String, Object> config,
334+
Map<String, Object> serviceSettingsMap,
335+
ChunkingSettings chunkingSettings,
336+
ActionListener<Model> modelListener
337+
) {
338+
339+
var esServiceSettingsBuilder = ElasticsearchInternalServiceSettings.fromRequestMap(serviceSettingsMap);
340+
341+
throwIfNotEmptyMap(config, name());
342+
throwIfNotEmptyMap(serviceSettingsMap, name());
343+
344+
modelListener.onResponse(
345+
new ElasticRerankerModel(
346+
inferenceEntityId,
347+
taskType,
348+
NAME,
349+
new ElasticRerankerServiceSettings(esServiceSettingsBuilder.build()),
350+
chunkingSettings
351+
)
352+
);
353+
}
354+
326355
private void e5Case(
327356
String inferenceEntityId,
328357
TaskType taskType,

0 commit comments

Comments
 (0)