Skip to content

Commit cc5d979

Browse files
prwhelanalbertzaharovits
authored andcommitted
[ML] Consolidate ExecutableActions (#110806)
- Created two new ExecutableActions, SenderExecutableAction and SingleInputSenderExecutableAction - Most chat completions are migrated over the SingleInputSenderExecutableAction - Every other Action are migrated over to SenderExecutableAction - RequestManagers and Error Message construction are migrated into the ActionCreator classes. Relate #110805
1 parent 088a187 commit cc5d979

File tree

44 files changed

+561
-1066
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

44 files changed

+561
-1066
lines changed

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/action/ActionUtils.java

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -24,13 +24,13 @@ public static ActionListener<InferenceServiceResults> wrapFailuresInElasticsearc
2424
String errorMessage,
2525
ActionListener<InferenceServiceResults> listener
2626
) {
27-
return ActionListener.wrap(listener::onResponse, e -> {
27+
return listener.delegateResponse((l, e) -> {
2828
var unwrappedException = ExceptionsHelper.unwrapCause(e);
2929

3030
if (unwrappedException instanceof ElasticsearchException esException) {
31-
listener.onFailure(esException);
31+
l.onFailure(esException);
3232
} else {
33-
listener.onFailure(createInternalServerError(unwrappedException, errorMessage));
33+
l.onFailure(createInternalServerError(unwrappedException, errorMessage));
3434
}
3535
});
3636
}
Lines changed: 7 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -5,44 +5,38 @@
55
* 2.0.
66
*/
77

8-
package org.elasticsearch.xpack.inference.external.action.amazonbedrock;
8+
package org.elasticsearch.xpack.inference.external.action;
99

10-
import org.elasticsearch.ElasticsearchException;
1110
import org.elasticsearch.action.ActionListener;
1211
import org.elasticsearch.core.TimeValue;
1312
import org.elasticsearch.inference.InferenceServiceResults;
14-
import org.elasticsearch.xpack.inference.external.action.ExecutableAction;
1513
import org.elasticsearch.xpack.inference.external.http.sender.InferenceInputs;
1614
import org.elasticsearch.xpack.inference.external.http.sender.RequestManager;
1715
import org.elasticsearch.xpack.inference.external.http.sender.Sender;
1816

1917
import java.util.Objects;
2018

21-
import static org.elasticsearch.xpack.inference.external.action.ActionUtils.createInternalServerError;
2219
import static org.elasticsearch.xpack.inference.external.action.ActionUtils.wrapFailuresInElasticsearchException;
2320

24-
public class AmazonBedrockEmbeddingsAction implements ExecutableAction {
21+
public class SenderExecutableAction implements ExecutableAction {
2522

2623
private final Sender sender;
2724
private final RequestManager requestManager;
28-
private final String errorMessage;
25+
private final String failedToSendRequestErrorMessage;
2926

30-
public AmazonBedrockEmbeddingsAction(Sender sender, RequestManager requestManager, String errorMessage) {
27+
public SenderExecutableAction(Sender sender, RequestManager requestManager, String failedToSendRequestErrorMessage) {
3128
this.sender = Objects.requireNonNull(sender);
3229
this.requestManager = Objects.requireNonNull(requestManager);
33-
this.errorMessage = Objects.requireNonNull(errorMessage);
30+
this.failedToSendRequestErrorMessage = Objects.requireNonNull(failedToSendRequestErrorMessage);
3431
}
3532

3633
@Override
3734
public void execute(InferenceInputs inferenceInputs, TimeValue timeout, ActionListener<InferenceServiceResults> listener) {
35+
var wrappedListener = wrapFailuresInElasticsearchException(failedToSendRequestErrorMessage, listener);
3836
try {
39-
ActionListener<InferenceServiceResults> wrappedListener = wrapFailuresInElasticsearchException(errorMessage, listener);
40-
4137
sender.send(requestManager, inferenceInputs, timeout, wrappedListener);
42-
} catch (ElasticsearchException e) {
43-
listener.onFailure(e);
4438
} catch (Exception e) {
45-
listener.onFailure(createInternalServerError(e, errorMessage));
39+
wrappedListener.onFailure(e);
4640
}
4741
}
4842
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License
4+
* 2.0; you may not use this file except in compliance with the Elastic License
5+
* 2.0.
6+
*/
7+
8+
package org.elasticsearch.xpack.inference.external.action;
9+
10+
import org.elasticsearch.ElasticsearchStatusException;
11+
import org.elasticsearch.action.ActionListener;
12+
import org.elasticsearch.core.TimeValue;
13+
import org.elasticsearch.inference.InferenceServiceResults;
14+
import org.elasticsearch.rest.RestStatus;
15+
import org.elasticsearch.xpack.inference.external.http.sender.DocumentsOnlyInput;
16+
import org.elasticsearch.xpack.inference.external.http.sender.InferenceInputs;
17+
import org.elasticsearch.xpack.inference.external.http.sender.RequestManager;
18+
import org.elasticsearch.xpack.inference.external.http.sender.Sender;
19+
20+
import java.util.Objects;
21+
22+
public class SingleInputSenderExecutableAction extends SenderExecutableAction {
23+
private final String requestTypeForInputValidationError;
24+
25+
public SingleInputSenderExecutableAction(
26+
Sender sender,
27+
RequestManager requestManager,
28+
String failedToSendRequestErrorMessage,
29+
String requestTypeForInputValidationError
30+
) {
31+
super(sender, requestManager, failedToSendRequestErrorMessage);
32+
this.requestTypeForInputValidationError = Objects.requireNonNull(requestTypeForInputValidationError);
33+
}
34+
35+
@Override
36+
public void execute(InferenceInputs inferenceInputs, TimeValue timeout, ActionListener<InferenceServiceResults> listener) {
37+
if (inferenceInputs instanceof DocumentsOnlyInput == false) {
38+
listener.onFailure(new ElasticsearchStatusException("Invalid inference input type", RestStatus.INTERNAL_SERVER_ERROR));
39+
return;
40+
}
41+
42+
var docsOnlyInput = (DocumentsOnlyInput) inferenceInputs;
43+
if (docsOnlyInput.getInputs().size() > 1) {
44+
listener.onFailure(
45+
new ElasticsearchStatusException(requestTypeForInputValidationError + " only accepts 1 input", RestStatus.BAD_REQUEST)
46+
);
47+
return;
48+
}
49+
50+
super.execute(inferenceInputs, timeout, listener);
51+
}
52+
53+
}

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/action/amazonbedrock/AmazonBedrockActionCreator.java

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
import org.elasticsearch.core.Nullable;
1111
import org.elasticsearch.core.TimeValue;
1212
import org.elasticsearch.xpack.inference.external.action.ExecutableAction;
13+
import org.elasticsearch.xpack.inference.external.action.SenderExecutableAction;
1314
import org.elasticsearch.xpack.inference.external.http.sender.AmazonBedrockChatCompletionRequestManager;
1415
import org.elasticsearch.xpack.inference.external.http.sender.AmazonBedrockEmbeddingsRequestManager;
1516
import org.elasticsearch.xpack.inference.external.http.sender.Sender;
@@ -43,14 +44,14 @@ public ExecutableAction create(AmazonBedrockEmbeddingsModel embeddingsModel, Map
4344
timeout
4445
);
4546
var errorMessage = constructFailedToSendRequestMessage(null, "Amazon Bedrock embeddings");
46-
return new AmazonBedrockEmbeddingsAction(sender, requestManager, errorMessage);
47+
return new SenderExecutableAction(sender, requestManager, errorMessage);
4748
}
4849

4950
@Override
5051
public ExecutableAction create(AmazonBedrockChatCompletionModel completionModel, Map<String, Object> taskSettings) {
5152
var overriddenModel = AmazonBedrockChatCompletionModel.of(completionModel, taskSettings);
5253
var requestManager = new AmazonBedrockChatCompletionRequestManager(overriddenModel, serviceComponents.threadPool(), timeout);
5354
var errorMessage = constructFailedToSendRequestMessage(null, "Amazon Bedrock completion");
54-
return new AmazonBedrockChatCompletionAction(sender, requestManager, errorMessage);
55+
return new SenderExecutableAction(sender, requestManager, errorMessage);
5556
}
5657
}

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/action/amazonbedrock/AmazonBedrockChatCompletionAction.java

Lines changed: 0 additions & 47 deletions
This file was deleted.

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/action/anthropic/AnthropicActionCreator.java

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,17 +8,22 @@
88
package org.elasticsearch.xpack.inference.external.action.anthropic;
99

1010
import org.elasticsearch.xpack.inference.external.action.ExecutableAction;
11+
import org.elasticsearch.xpack.inference.external.action.SingleInputSenderExecutableAction;
12+
import org.elasticsearch.xpack.inference.external.http.sender.AnthropicCompletionRequestManager;
1113
import org.elasticsearch.xpack.inference.external.http.sender.Sender;
1214
import org.elasticsearch.xpack.inference.services.ServiceComponents;
1315
import org.elasticsearch.xpack.inference.services.anthropic.completion.AnthropicChatCompletionModel;
1416

1517
import java.util.Map;
1618
import java.util.Objects;
1719

20+
import static org.elasticsearch.xpack.inference.external.action.ActionUtils.constructFailedToSendRequestMessage;
21+
1822
/**
1923
* Provides a way to construct an {@link ExecutableAction} using the visitor pattern based on the anthropic model type.
2024
*/
2125
public class AnthropicActionCreator implements AnthropicActionVisitor {
26+
private static final String ERROR_PREFIX = "Anthropic chat completions";
2227
private final Sender sender;
2328
private final ServiceComponents serviceComponents;
2429

@@ -30,7 +35,8 @@ public AnthropicActionCreator(Sender sender, ServiceComponents serviceComponents
3035
@Override
3136
public ExecutableAction create(AnthropicChatCompletionModel model, Map<String, Object> taskSettings) {
3237
var overriddenModel = AnthropicChatCompletionModel.of(model, taskSettings);
33-
34-
return new AnthropicChatCompletionAction(sender, overriddenModel, serviceComponents);
38+
var requestCreator = AnthropicCompletionRequestManager.of(overriddenModel, serviceComponents.threadPool());
39+
var errorMessage = constructFailedToSendRequestMessage(overriddenModel.getUri(), ERROR_PREFIX);
40+
return new SingleInputSenderExecutableAction(sender, requestCreator, errorMessage, ERROR_PREFIX);
3541
}
3642
}

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/action/anthropic/AnthropicChatCompletionAction.java

Lines changed: 0 additions & 68 deletions
This file was deleted.

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/action/azureaistudio/AzureAiStudioAction.java

Lines changed: 0 additions & 45 deletions
This file was deleted.

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/action/azureaistudio/AzureAiStudioActionCreator.java

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
package org.elasticsearch.xpack.inference.external.action.azureaistudio;
99

1010
import org.elasticsearch.xpack.inference.external.action.ExecutableAction;
11+
import org.elasticsearch.xpack.inference.external.action.SenderExecutableAction;
1112
import org.elasticsearch.xpack.inference.external.http.sender.AzureAiStudioChatCompletionRequestManager;
1213
import org.elasticsearch.xpack.inference.external.http.sender.AzureAiStudioEmbeddingsRequestManager;
1314
import org.elasticsearch.xpack.inference.external.http.sender.Sender;
@@ -34,7 +35,7 @@ public ExecutableAction create(AzureAiStudioChatCompletionModel completionModel,
3435
var overriddenModel = AzureAiStudioChatCompletionModel.of(completionModel, taskSettings);
3536
var requestManager = new AzureAiStudioChatCompletionRequestManager(overriddenModel, serviceComponents.threadPool());
3637
var errorMessage = constructFailedToSendRequestMessage(completionModel.uri(), "Azure AI Studio completion");
37-
return new AzureAiStudioAction(sender, requestManager, errorMessage);
38+
return new SenderExecutableAction(sender, requestManager, errorMessage);
3839
}
3940

4041
@Override
@@ -46,6 +47,6 @@ public ExecutableAction create(AzureAiStudioEmbeddingsModel embeddingsModel, Map
4647
serviceComponents.threadPool()
4748
);
4849
var errorMessage = constructFailedToSendRequestMessage(embeddingsModel.uri(), "Azure AI Studio embeddings");
49-
return new AzureAiStudioAction(sender, requestManager, errorMessage);
50+
return new SenderExecutableAction(sender, requestManager, errorMessage);
5051
}
5152
}

0 commit comments

Comments
 (0)