Skip to content

Commit 071e7ce

Browse files
authored
[ML] Move code specific to the Elasticsearch in cluster services to those sevices (#113749)
Remove the platform arch argument from parseRequest and move code used by internal services out of the transport action into the service.
1 parent 0fbb3bc commit 071e7ce

File tree

44 files changed

+294
-330
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

44 files changed

+294
-330
lines changed

muted-tests.yml

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -278,12 +278,9 @@ tests:
278278
- class: org.elasticsearch.xpack.ml.integration.MlJobIT
279279
method: testCreateJobsWithIndexNameOption
280280
issue: https://github.com/elastic/elasticsearch/issues/113528
281-
- class: org.elasticsearch.xpack.inference.TextEmbeddingCrudIT
282-
method: testPutE5WithTrainedModelAndInference
283-
issue: https://github.com/elastic/elasticsearch/issues/113565
284-
- class: org.elasticsearch.xpack.inference.TextEmbeddingCrudIT
285-
method: testPutE5Small_withPlatformAgnosticVariant
286-
issue: https://github.com/elastic/elasticsearch/issues/113577
281+
- class: org.elasticsearch.validation.DotPrefixClientYamlTestSuiteIT
282+
method: test {p0=dot_prefix/10_basic/Deprecated index template with a dot prefix index pattern}
283+
issue: https://github.com/elastic/elasticsearch/issues/113529
287284
- class: org.elasticsearch.xpack.ml.integration.MlJobIT
288285
method: testCantCreateJobWithSameID
289286
issue: https://github.com/elastic/elasticsearch/issues/113581

server/src/main/java/org/elasticsearch/inference/InferenceService.java

Lines changed: 1 addition & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -39,17 +39,9 @@ default void init(Client client) {}
3939
* @param modelId Model Id
4040
* @param taskType The model task type
4141
* @param config Configuration options including the secrets
42-
* @param platformArchitectures The Set of platform architectures (OS name and hardware architecture)
43-
* the cluster nodes and models are running on.
4442
* @param parsedModelListener A listener which will handle the resulting model or failure
4543
*/
46-
void parseRequestConfig(
47-
String modelId,
48-
TaskType taskType,
49-
Map<String, Object> config,
50-
Set<String> platformArchitectures,
51-
ActionListener<Model> parsedModelListener
52-
);
44+
void parseRequestConfig(String modelId, TaskType taskType, Map<String, Object> config, ActionListener<Model> parsedModelListener);
5345

5446
/**
5547
* Parse model configuration from {@code config map} from persisted storage and return the parsed {@link Model}. This requires that
@@ -155,17 +147,6 @@ default void putModel(Model modelVariant, ActionListener<Boolean> listener) {
155147
listener.onResponse(true);
156148
}
157149

158-
/**
159-
* Checks if the modelId has been downloaded to the local Elasticsearch cluster using the trained models API
160-
* The default action does nothing except acknowledge the request (false).
161-
* Any internal services should Override this method.
162-
* @param model
163-
* @param listener The listener
164-
*/
165-
default void isModelDownloaded(Model model, ActionListener<Boolean> listener) {
166-
listener.onResponse(false);
167-
};
168-
169150
/**
170151
* Optionally test the new model configuration in the inference service.
171152
* This function should be called when the model is first created, the
@@ -188,14 +169,6 @@ default Model updateModelWithEmbeddingDetails(Model model, int embeddingSize) {
188169
return model;
189170
}
190171

191-
/**
192-
* Return true if this model is hosted in the local Elasticsearch cluster
193-
* @return True if in cluster
194-
*/
195-
default boolean isInClusterService() {
196-
return false;
197-
}
198-
199172
/**
200173
* Defines the version required across all clusters to use this service
201174
* @return {@link TransportVersion} specifying the version

server/src/main/java/org/elasticsearch/inference/InferenceServiceExtension.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
package org.elasticsearch.inference;
1111

1212
import org.elasticsearch.client.internal.Client;
13+
import org.elasticsearch.threadpool.ThreadPool;
1314

1415
import java.util.List;
1516

@@ -20,7 +21,7 @@ public interface InferenceServiceExtension {
2021

2122
List<Factory> getInferenceServiceFactories();
2223

23-
record InferenceServiceFactoryContext(Client client) {}
24+
record InferenceServiceFactoryContext(Client client, ThreadPool threadPool) {}
2425

2526
interface Factory {
2627
/**

x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceBaseRestTest.java

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,10 @@ protected String getTestRestCluster() {
5353
@Override
5454
protected Settings restClientSettings() {
5555
String token = basicAuthHeaderValue("x_pack_rest_user", new SecureString("x-pack-test-password".toCharArray()));
56-
return Settings.builder().put(ThreadContext.PREFIX + ".Authorization", token).build();
56+
return Settings.builder()
57+
.put(ThreadContext.PREFIX + ".Authorization", token)
58+
.put(CLIENT_SOCKET_TIMEOUT, "120s") // Long timeout for model download
59+
.build();
5760
}
5861

5962
static String mockSparseServiceModelConfig() {

x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestDenseInferenceServiceExtension.java

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,6 @@
3838
import java.util.ArrayList;
3939
import java.util.List;
4040
import java.util.Map;
41-
import java.util.Set;
4241

4342
public class TestDenseInferenceServiceExtension implements InferenceServiceExtension {
4443
@Override
@@ -76,7 +75,6 @@ public void parseRequestConfig(
7675
String modelId,
7776
TaskType taskType,
7877
Map<String, Object> config,
79-
Set<String> platformArchitectures,
8078
ActionListener<Model> parsedModelListener
8179
) {
8280
var serviceSettingsMap = (Map<String, Object>) config.remove(ModelConfigurations.SERVICE_SETTINGS);

x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestRerankingServiceExtension.java

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,6 @@
3434
import java.util.ArrayList;
3535
import java.util.List;
3636
import java.util.Map;
37-
import java.util.Set;
3837

3938
public class TestRerankingServiceExtension implements InferenceServiceExtension {
4039
@Override
@@ -67,7 +66,6 @@ public void parseRequestConfig(
6766
String modelId,
6867
TaskType taskType,
6968
Map<String, Object> config,
70-
Set<String> platformArchitectures,
7169
ActionListener<Model> parsedModelListener
7270
) {
7371
var serviceSettingsMap = (Map<String, Object>) config.remove(ModelConfigurations.SERVICE_SETTINGS);

x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestSparseInferenceServiceExtension.java

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,6 @@
3737
import java.util.ArrayList;
3838
import java.util.List;
3939
import java.util.Map;
40-
import java.util.Set;
4140

4241
public class TestSparseInferenceServiceExtension implements InferenceServiceExtension {
4342
@Override
@@ -70,7 +69,6 @@ public void parseRequestConfig(
7069
String modelId,
7170
TaskType taskType,
7271
Map<String, Object> config,
73-
Set<String> platformArchitectures,
7472
ActionListener<Model> parsedModelListener
7573
) {
7674
var serviceSettingsMap = (Map<String, Object>) config.remove(ModelConfigurations.SERVICE_SETTINGS);

x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestStreamingCompletionServiceExtension.java

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,6 @@ public void parseRequestConfig(
6767
String modelId,
6868
TaskType taskType,
6969
Map<String, Object> config,
70-
Set<String> platformArchitectures,
7170
ActionListener<Model> parsedModelListener
7271
) {
7372
var serviceSettingsMap = (Map<String, Object>) config.remove(ModelConfigurations.SERVICE_SETTINGS);

x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/ModelRegistryIT.java

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
import org.elasticsearch.plugins.Plugin;
2424
import org.elasticsearch.reindex.ReindexPlugin;
2525
import org.elasticsearch.test.ESSingleNodeTestCase;
26+
import org.elasticsearch.threadpool.ThreadPool;
2627
import org.elasticsearch.xcontent.ToXContentObject;
2728
import org.elasticsearch.xcontent.XContentBuilder;
2829
import org.elasticsearch.xpack.inference.InferencePlugin;
@@ -117,7 +118,9 @@ public void testGetModel() throws Exception {
117118

118119
assertEquals(model.getConfigurations().getService(), modelHolder.get().service());
119120

120-
var elserService = new ElserInternalService(new InferenceServiceExtension.InferenceServiceFactoryContext(mock(Client.class)));
121+
var elserService = new ElserInternalService(
122+
new InferenceServiceExtension.InferenceServiceFactoryContext(mock(Client.class), mock(ThreadPool.class))
123+
);
121124
ElserInternalModel roundTripModel = elserService.parsePersistedConfigWithSecrets(
122125
modelHolder.get().inferenceEntityId(),
123126
modelHolder.get().taskType(),

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java

Lines changed: 12 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -206,7 +206,7 @@ public Collection<?> createComponents(PluginServices services) {
206206
);
207207
}
208208

209-
var factoryContext = new InferenceServiceExtension.InferenceServiceFactoryContext(services.client());
209+
var factoryContext = new InferenceServiceExtension.InferenceServiceFactoryContext(services.client(), services.threadPool());
210210
// This must be done after the HttpRequestSenderFactory is created so that the services can get the
211211
// reference correctly
212212
var registry = new InferenceServiceRegistry(inferenceServices, factoryContext);
@@ -299,15 +299,17 @@ public Collection<SystemIndexDescriptor> getSystemIndexDescriptors(Settings sett
299299

300300
@Override
301301
public List<ExecutorBuilder<?>> getExecutorBuilders(Settings settingsToUse) {
302-
return List.of(
303-
new ScalingExecutorBuilder(
304-
UTILITY_THREAD_POOL_NAME,
305-
0,
306-
10,
307-
TimeValue.timeValueMinutes(10),
308-
false,
309-
"xpack.inference.utility_thread_pool"
310-
)
302+
return List.of(inferenceUtilityExecutor(settings));
303+
}
304+
305+
public static ExecutorBuilder<?> inferenceUtilityExecutor(Settings settings) {
306+
return new ScalingExecutorBuilder(
307+
UTILITY_THREAD_POOL_NAME,
308+
0,
309+
10,
310+
TimeValue.timeValueMinutes(10),
311+
false,
312+
"xpack.inference.utility_thread_pool"
311313
);
312314
}
313315

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportPutInferenceModelAction.java

Lines changed: 9 additions & 71 deletions
Original file line numberDiff line numberDiff line change
@@ -12,15 +12,13 @@
1212
import org.elasticsearch.ElasticsearchStatusException;
1313
import org.elasticsearch.action.ActionListener;
1414
import org.elasticsearch.action.support.ActionFilters;
15-
import org.elasticsearch.action.support.SubscribableListener;
1615
import org.elasticsearch.action.support.master.TransportMasterNodeAction;
1716
import org.elasticsearch.client.internal.Client;
1817
import org.elasticsearch.cluster.ClusterState;
1918
import org.elasticsearch.cluster.block.ClusterBlockException;
2019
import org.elasticsearch.cluster.block.ClusterBlockLevel;
2120
import org.elasticsearch.cluster.metadata.IndexNameExpressionResolver;
2221
import org.elasticsearch.cluster.service.ClusterService;
23-
import org.elasticsearch.common.settings.ClusterSettings;
2422
import org.elasticsearch.common.settings.Settings;
2523
import org.elasticsearch.common.util.concurrent.EsExecutors;
2624
import org.elasticsearch.common.xcontent.XContentHelper;
@@ -38,17 +36,14 @@
3836
import org.elasticsearch.xcontent.XContentParser;
3937
import org.elasticsearch.xcontent.XContentParserConfiguration;
4038
import org.elasticsearch.xpack.core.inference.action.PutInferenceModelAction;
41-
import org.elasticsearch.xpack.core.ml.MachineLearningField;
4239
import org.elasticsearch.xpack.core.ml.inference.assignment.TrainedModelAssignmentUtils;
4340
import org.elasticsearch.xpack.core.ml.job.messages.Messages;
4441
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
45-
import org.elasticsearch.xpack.core.ml.utils.MlPlatformArchitecturesUtil;
4642
import org.elasticsearch.xpack.inference.InferencePlugin;
4743
import org.elasticsearch.xpack.inference.registry.ModelRegistry;
4844

4945
import java.io.IOException;
5046
import java.util.Map;
51-
import java.util.Set;
5247

5348
import static org.elasticsearch.core.Strings.format;
5449

@@ -156,50 +151,20 @@ protected void masterOperation(
156151
return;
157152
}
158153

159-
if (service.get().isInClusterService()) {
160-
// Find the cluster platform as the service may need that
161-
// information when creating the model
162-
MlPlatformArchitecturesUtil.getMlNodesArchitecturesSet(listener.delegateFailureAndWrap((delegate, architectures) -> {
163-
if (architectures.isEmpty() && clusterIsInElasticCloud(clusterService.getClusterSettings())) {
164-
parseAndStoreModel(
165-
service.get(),
166-
request.getInferenceEntityId(),
167-
resolvedTaskType,
168-
requestAsMap,
169-
// In Elastic cloud ml nodes run on Linux x86
170-
Set.of("linux-x86_64"),
171-
delegate
172-
);
173-
} else {
174-
// The architecture field could be an empty set, the individual services will need to handle that
175-
parseAndStoreModel(
176-
service.get(),
177-
request.getInferenceEntityId(),
178-
resolvedTaskType,
179-
requestAsMap,
180-
architectures,
181-
delegate
182-
);
183-
}
184-
}), client, threadPool.executor(InferencePlugin.UTILITY_THREAD_POOL_NAME));
185-
} else {
186-
// Not an in cluster service, it does not care about the cluster platform
187-
parseAndStoreModel(service.get(), request.getInferenceEntityId(), resolvedTaskType, requestAsMap, Set.of(), listener);
188-
}
154+
parseAndStoreModel(service.get(), request.getInferenceEntityId(), resolvedTaskType, requestAsMap, listener);
189155
}
190156

191157
private void parseAndStoreModel(
192158
InferenceService service,
193159
String inferenceEntityId,
194160
TaskType taskType,
195161
Map<String, Object> config,
196-
Set<String> platformArchitectures,
197162
ActionListener<PutInferenceModelAction.Response> listener
198163
) {
199164
ActionListener<Model> storeModelListener = listener.delegateFailureAndWrap(
200165
(delegate, verifiedModel) -> modelRegistry.storeModel(
201166
verifiedModel,
202-
ActionListener.wrap(r -> putAndStartModel(service, verifiedModel, delegate), e -> {
167+
ActionListener.wrap(r -> startInferenceEndpoint(service, verifiedModel, delegate), e -> {
203168
if (e.getCause() instanceof StrictDynamicMappingException && e.getCause().getMessage().contains("chunking_settings")) {
204169
delegate.onFailure(
205170
new ElasticsearchStatusException(
@@ -223,36 +188,15 @@ private void parseAndStoreModel(
223188
}
224189
});
225190

226-
service.parseRequestConfig(inferenceEntityId, taskType, config, platformArchitectures, parsedModelListener);
227-
191+
service.parseRequestConfig(inferenceEntityId, taskType, config, parsedModelListener);
228192
}
229193

230-
private void putAndStartModel(InferenceService service, Model model, ActionListener<PutInferenceModelAction.Response> finalListener) {
231-
SubscribableListener.<Boolean>newForked(listener -> {
232-
var errorCatchingListener = ActionListener.<Boolean>wrap(listener::onResponse, e -> { listener.onResponse(false); });
233-
service.isModelDownloaded(model, errorCatchingListener);
234-
}).<Boolean>andThen((listener, isDownloaded) -> {
235-
if (isDownloaded == false) {
236-
service.putModel(model, listener);
237-
} else {
238-
listener.onResponse(true);
239-
}
240-
}).<PutInferenceModelAction.Response>andThen((listener, modelDidPut) -> {
241-
if (modelDidPut) {
242-
if (skipValidationAndStart) {
243-
listener.onResponse(new PutInferenceModelAction.Response(model.getConfigurations()));
244-
} else {
245-
service.start(
246-
model,
247-
listener.delegateFailureAndWrap(
248-
(l3, ok) -> l3.onResponse(new PutInferenceModelAction.Response(model.getConfigurations()))
249-
)
250-
);
251-
}
252-
} else {
253-
logger.warn("Failed to put model [{}]", model.getInferenceEntityId());
254-
}
255-
}).addListener(finalListener);
194+
private void startInferenceEndpoint(InferenceService service, Model model, ActionListener<PutInferenceModelAction.Response> listener) {
195+
if (skipValidationAndStart) {
196+
listener.onResponse(new PutInferenceModelAction.Response(model.getConfigurations()));
197+
} else {
198+
service.start(model, listener.map(started -> new PutInferenceModelAction.Response(model.getConfigurations())));
199+
}
256200
}
257201

258202
private Map<String, Object> requestToMap(PutInferenceModelAction.Request request) throws IOException {
@@ -276,12 +220,6 @@ protected ClusterBlockException checkBlock(PutInferenceModelAction.Request reque
276220
return state.blocks().globalBlockedException(ClusterBlockLevel.METADATA_WRITE);
277221
}
278222

279-
static boolean clusterIsInElasticCloud(ClusterSettings settings) {
280-
// use a heuristic to determine if in Elastic cloud.
281-
// One such heuristic is where USE_AUTO_MACHINE_MEMORY_PERCENT == true
282-
return settings.get(MachineLearningField.USE_AUTO_MACHINE_MEMORY_PERCENT);
283-
}
284-
285223
/**
286224
* task_type can be specified as either a URL parameter or in the
287225
* request body. Resolve which to use or throw if the settings are

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/AlibabaCloudSearchService.java

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,6 @@
4242

4343
import java.util.List;
4444
import java.util.Map;
45-
import java.util.Set;
4645

4746
import static org.elasticsearch.xpack.core.inference.action.InferenceAction.Request.DEFAULT_TIMEOUT;
4847
import static org.elasticsearch.xpack.inference.services.ServiceUtils.createInvalidModelException;
@@ -69,7 +68,6 @@ public void parseRequestConfig(
6968
String inferenceEntityId,
7069
TaskType taskType,
7170
Map<String, Object> config,
72-
Set<String> platformArchitectures,
7371
ActionListener<Model> parsedModelListener
7472
) {
7573
try {

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockService.java

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,6 @@
4141
import java.io.IOException;
4242
import java.util.List;
4343
import java.util.Map;
44-
import java.util.Set;
4544

4645
import static org.elasticsearch.TransportVersions.ML_INFERENCE_AMAZON_BEDROCK_ADDED;
4746
import static org.elasticsearch.xpack.inference.services.ServiceUtils.createInvalidModelException;
@@ -121,7 +120,6 @@ public void parseRequestConfig(
121120
String modelId,
122121
TaskType taskType,
123122
Map<String, Object> config,
124-
Set<String> platformArchitectures,
125123
ActionListener<Model> parsedModelListener
126124
) {
127125
try {

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/anthropic/AnthropicService.java

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,6 @@
3333

3434
import java.util.List;
3535
import java.util.Map;
36-
import java.util.Set;
3736

3837
import static org.elasticsearch.xpack.inference.services.ServiceUtils.createInvalidModelException;
3938
import static org.elasticsearch.xpack.inference.services.ServiceUtils.parsePersistedConfigErrorMsg;
@@ -58,7 +57,6 @@ public void parseRequestConfig(
5857
String inferenceEntityId,
5958
TaskType taskType,
6059
Map<String, Object> config,
61-
Set<String> platformArchitectures,
6260
ActionListener<Model> parsedModelListener
6361
) {
6462
try {

0 commit comments

Comments
 (0)