Skip to content

Commit 5082f32

Browse files
Fix AlibabaCloudSearchCompletionAction not accepting ChatCompletionInputs (#125023) (#125414)
* Fix AlibabaCloudSearchCompletionAction not accepting ChatCompletionInputs * Update docs/changelog/125023.yaml * Fix unit tests Co-authored-by: Elastic Machine <[email protected]>
1 parent 1bce184 commit 5082f32

File tree

3 files changed

+168
-15
lines changed

3 files changed

+168
-15
lines changed

docs/changelog/125023.yaml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
pr: 125023
2+
summary: Fix `AlibabaCloudSearchCompletionAction` not accepting `ChatCompletionInputs`
3+
area: Machine Learning
4+
type: bug
5+
issues: []

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/action/alibabacloudsearch/AlibabaCloudSearchCompletionAction.java

Lines changed: 3 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -14,20 +14,18 @@
1414
import org.elasticsearch.action.ActionListener;
1515
import org.elasticsearch.core.TimeValue;
1616
import org.elasticsearch.inference.InferenceServiceResults;
17-
import org.elasticsearch.inference.TaskType;
1817
import org.elasticsearch.rest.RestStatus;
1918
import org.elasticsearch.xpack.inference.external.action.ExecutableAction;
2019
import org.elasticsearch.xpack.inference.external.alibabacloudsearch.AlibabaCloudSearchAccount;
2120
import org.elasticsearch.xpack.inference.external.http.sender.AlibabaCloudSearchCompletionRequestManager;
22-
import org.elasticsearch.xpack.inference.external.http.sender.DocumentsOnlyInput;
21+
import org.elasticsearch.xpack.inference.external.http.sender.ChatCompletionInput;
2322
import org.elasticsearch.xpack.inference.external.http.sender.InferenceInputs;
2423
import org.elasticsearch.xpack.inference.external.http.sender.Sender;
2524
import org.elasticsearch.xpack.inference.services.ServiceComponents;
2625
import org.elasticsearch.xpack.inference.services.alibabacloudsearch.completion.AlibabaCloudSearchCompletionModel;
2726

2827
import java.util.Objects;
2928

30-
import static org.elasticsearch.core.Strings.format;
3129
import static org.elasticsearch.xpack.inference.external.action.ActionUtils.constructFailedToSendRequestMessage;
3230
import static org.elasticsearch.xpack.inference.external.action.ActionUtils.createInternalServerError;
3331
import static org.elasticsearch.xpack.inference.external.action.ActionUtils.wrapFailuresInElasticsearchException;
@@ -51,18 +49,8 @@ public AlibabaCloudSearchCompletionAction(Sender sender, AlibabaCloudSearchCompl
5149

5250
@Override
5351
public void execute(InferenceInputs inferenceInputs, TimeValue timeout, ActionListener<InferenceServiceResults> listener) {
54-
if (inferenceInputs instanceof DocumentsOnlyInput == false) {
55-
listener.onFailure(
56-
new ElasticsearchStatusException(
57-
format("Invalid inference input type, task type [%s] do not support Field [query]", TaskType.COMPLETION),
58-
RestStatus.INTERNAL_SERVER_ERROR
59-
)
60-
);
61-
return;
62-
}
63-
64-
var docsOnlyInput = (DocumentsOnlyInput) inferenceInputs;
65-
if (docsOnlyInput.getInputs().size() % 2 == 0) {
52+
var completionInput = inferenceInputs.castTo(ChatCompletionInput.class);
53+
if (completionInput.getInputs().size() % 2 == 0) {
6654
listener.onFailure(
6755
new ElasticsearchStatusException(
6856
"Alibaba Completion's inputs must be an odd number. The last input is the current query, "
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,160 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License
4+
* 2.0; you may not use this file except in compliance with the Elastic License
5+
* 2.0.
6+
*/
7+
8+
package org.elasticsearch.xpack.inference.external.action.alibabacloudsearch;
9+
10+
import org.elasticsearch.ElasticsearchException;
11+
import org.elasticsearch.ElasticsearchStatusException;
12+
import org.elasticsearch.action.ActionListener;
13+
import org.elasticsearch.action.support.PlainActionFuture;
14+
import org.elasticsearch.common.settings.Settings;
15+
import org.elasticsearch.core.TimeValue;
16+
import org.elasticsearch.inference.InferenceServiceResults;
17+
import org.elasticsearch.inference.TaskType;
18+
import org.elasticsearch.rest.RestStatus;
19+
import org.elasticsearch.test.ESTestCase;
20+
import org.elasticsearch.test.http.MockWebServer;
21+
import org.elasticsearch.threadpool.ThreadPool;
22+
import org.elasticsearch.xpack.core.inference.action.InferenceAction;
23+
import org.elasticsearch.xpack.core.inference.results.ChatCompletionResults;
24+
import org.elasticsearch.xpack.inference.external.action.ExecutableAction;
25+
import org.elasticsearch.xpack.inference.external.http.HttpClientManager;
26+
import org.elasticsearch.xpack.inference.external.http.sender.ChatCompletionInput;
27+
import org.elasticsearch.xpack.inference.external.http.sender.DocumentsOnlyInput;
28+
import org.elasticsearch.xpack.inference.external.http.sender.Sender;
29+
import org.elasticsearch.xpack.inference.logging.ThrottlerManager;
30+
import org.elasticsearch.xpack.inference.services.ServiceComponentsTests;
31+
import org.elasticsearch.xpack.inference.services.alibabacloudsearch.completion.AlibabaCloudSearchCompletionModelTests;
32+
import org.elasticsearch.xpack.inference.services.alibabacloudsearch.completion.AlibabaCloudSearchCompletionServiceSettingsTests;
33+
import org.elasticsearch.xpack.inference.services.alibabacloudsearch.completion.AlibabaCloudSearchCompletionTaskSettingsTests;
34+
import org.junit.After;
35+
import org.junit.Before;
36+
37+
import java.io.IOException;
38+
import java.util.List;
39+
import java.util.concurrent.TimeUnit;
40+
41+
import static org.apache.lucene.tests.util.LuceneTestCase.expectThrows;
42+
import static org.elasticsearch.xpack.inference.Utils.inferenceUtilityPool;
43+
import static org.elasticsearch.xpack.inference.Utils.mockClusterServiceEmpty;
44+
import static org.elasticsearch.xpack.inference.external.action.ActionUtils.constructFailedToSendRequestMessage;
45+
import static org.elasticsearch.xpack.inference.results.ChatCompletionResultsTests.buildExpectationCompletion;
46+
import static org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettingsTests.getSecretSettingsMap;
47+
import static org.hamcrest.MatcherAssert.assertThat;
48+
import static org.hamcrest.Matchers.is;
49+
import static org.mockito.ArgumentMatchers.any;
50+
import static org.mockito.Mockito.doAnswer;
51+
import static org.mockito.Mockito.doThrow;
52+
import static org.mockito.Mockito.mock;
53+
54+
public class AlibabaCloudSearchCompletionActionTests extends ESTestCase {
55+
56+
private static final TimeValue TIMEOUT = new TimeValue(30, TimeUnit.SECONDS);
57+
private final MockWebServer webServer = new MockWebServer();
58+
private ThreadPool threadPool;
59+
private HttpClientManager clientManager;
60+
61+
@Before
62+
public void init() throws IOException {
63+
webServer.start();
64+
threadPool = createThreadPool(inferenceUtilityPool());
65+
clientManager = HttpClientManager.create(Settings.EMPTY, threadPool, mockClusterServiceEmpty(), mock(ThrottlerManager.class));
66+
}
67+
68+
@After
69+
public void shutdown() throws IOException {
70+
clientManager.close();
71+
terminate(threadPool);
72+
webServer.close();
73+
}
74+
75+
public void testExecute_Success() {
76+
var sender = mock(Sender.class);
77+
78+
var resultString = randomAlphaOfLength(100);
79+
doAnswer(invocation -> {
80+
ActionListener<InferenceServiceResults> listener = invocation.getArgument(3);
81+
listener.onResponse(new ChatCompletionResults(List.of(new ChatCompletionResults.Result(resultString))));
82+
83+
return Void.TYPE;
84+
}).when(sender).send(any(), any(), any(), any());
85+
var action = createAction(threadPool, sender);
86+
87+
PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();
88+
action.execute(new ChatCompletionInput(List.of(randomAlphaOfLength(10))), InferenceAction.Request.DEFAULT_TIMEOUT, listener);
89+
90+
var result = listener.actionGet(TIMEOUT);
91+
assertThat(result.asMap(), is(buildExpectationCompletion(List.of(resultString))));
92+
}
93+
94+
public void testExecute_ListenerThrowsElasticsearchException_WhenSenderThrowsElasticsearchException() {
95+
var sender = mock(Sender.class);
96+
doThrow(new ElasticsearchException("error")).when(sender).send(any(), any(), any(), any());
97+
var action = createAction(threadPool, sender);
98+
99+
PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();
100+
action.execute(new ChatCompletionInput(List.of(randomAlphaOfLength(10))), InferenceAction.Request.DEFAULT_TIMEOUT, listener);
101+
102+
var thrownException = expectThrows(ElasticsearchException.class, () -> listener.actionGet(TIMEOUT));
103+
assertThat(thrownException.getMessage(), is("error"));
104+
}
105+
106+
public void testExecute_ListenerThrowsInternalServerError_WhenSenderThrowsException() {
107+
var sender = mock(Sender.class);
108+
doThrow(new RuntimeException("error")).when(sender).send(any(), any(), any(), any());
109+
var action = createAction(threadPool, sender);
110+
111+
PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();
112+
action.execute(new ChatCompletionInput(List.of(randomAlphaOfLength(10))), InferenceAction.Request.DEFAULT_TIMEOUT, listener);
113+
114+
var thrownException = expectThrows(ElasticsearchException.class, () -> listener.actionGet(TIMEOUT));
115+
assertThat(thrownException.getMessage(), is(constructFailedToSendRequestMessage("AlibabaCloud Search completion")));
116+
}
117+
118+
public void testExecute_ThrowsIllegalArgumentException_WhenInputIsNotChatCompletionInput() {
119+
var action = createAction(threadPool, mock(Sender.class));
120+
121+
PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();
122+
assertThrows(IllegalArgumentException.class, () -> {
123+
action.execute(new DocumentsOnlyInput(List.of(randomAlphaOfLength(10))), InferenceAction.Request.DEFAULT_TIMEOUT, listener);
124+
});
125+
}
126+
127+
public void testExecute_ListenerThrowsElasticsearchStatusException_WhenInputSizeIsEven() {
128+
var action = createAction(threadPool, mock(Sender.class));
129+
130+
PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();
131+
action.execute(
132+
new ChatCompletionInput(List.of(randomAlphaOfLength(10), randomAlphaOfLength(10))),
133+
InferenceAction.Request.DEFAULT_TIMEOUT,
134+
listener
135+
);
136+
137+
var thrownException = expectThrows(ElasticsearchStatusException.class, () -> listener.actionGet(TIMEOUT));
138+
assertThat(
139+
thrownException.getMessage(),
140+
is(
141+
"Alibaba Completion's inputs must be an odd number. The last input is the current query, "
142+
+ "all preceding inputs are the completion history as pairs of user input and the assistant's response."
143+
)
144+
);
145+
assertThat(thrownException.status(), is(RestStatus.BAD_REQUEST));
146+
}
147+
148+
private ExecutableAction createAction(ThreadPool threadPool, Sender sender) {
149+
var model = AlibabaCloudSearchCompletionModelTests.createModel(
150+
"completion_test",
151+
TaskType.COMPLETION,
152+
AlibabaCloudSearchCompletionServiceSettingsTests.getServiceSettingsMap("completion_test", "host", "default"),
153+
AlibabaCloudSearchCompletionTaskSettingsTests.getTaskSettingsMap(null),
154+
getSecretSettingsMap("secret")
155+
);
156+
157+
var serviceComponents = ServiceComponentsTests.createWithEmptySettings(threadPool);
158+
return new AlibabaCloudSearchCompletionAction(sender, model, serviceComponents);
159+
}
160+
}

0 commit comments

Comments
 (0)