Skip to content

Commit df6d07f

Browse files
[ML] Refactoring inference API non-streaming response validation error object check (#126725) (#126782)
* Refactoring so that non-streaming does not check for error object * Fixing test
1 parent 06a6f7a commit df6d07f

File tree

9 files changed

+85
-22
lines changed

9 files changed

+85
-22
lines changed

Diff for: x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/retry/BaseResponseHandler.java

+12-4
Original file line numberDiff line numberDiff line change
@@ -76,13 +76,21 @@ public String getRequestType() {
7676
}
7777

7878
@Override
79-
public void validateResponse(ThrottlerManager throttlerManager, Logger logger, Request request, HttpResult result) {
79+
public void validateResponse(
80+
ThrottlerManager throttlerManager,
81+
Logger logger,
82+
Request request,
83+
HttpResult result,
84+
boolean checkForErrorObject
85+
) {
8086
checkForFailureStatusCode(request, result);
8187
checkForEmptyBody(throttlerManager, logger, request, result);
8288

83-
// When the response is streamed the status code could be 200 but the error object will be set
84-
// so we need to check for that specifically
85-
checkForErrorObject(request, result);
89+
if (checkForErrorObject) {
90+
// When the response is streamed the status code could be 200 but the error object will be set
91+
// so we need to check for that specifically
92+
checkForErrorObject(request, result);
93+
}
8694
}
8795

8896
protected abstract void checkForFailureStatusCode(Request request, HttpResult result);

Diff for: x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/retry/ResponseHandler.java

+4-1
Original file line numberDiff line numberDiff line change
@@ -29,9 +29,12 @@ public interface ResponseHandler {
2929
* @param logger the logger to use for logging
3030
* @param request the original request
3131
* @param result the response from the server
32+
* @param checkForErrorObject if true, the validation function should check for the presence of an error object even if the status code
33+
* indicates a success
3234
* @throws RetryException if the response is invalid
3335
*/
34-
void validateResponse(ThrottlerManager throttlerManager, Logger logger, Request request, HttpResult result) throws RetryException;
36+
void validateResponse(ThrottlerManager throttlerManager, Logger logger, Request request, HttpResult result, boolean checkForErrorObject)
37+
throws RetryException;
3538

3639
/**
3740
* A method for parsing the response from the server.

Diff for: x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/retry/RetryingHttpSender.java

+2-2
Original file line numberDiff line numberDiff line change
@@ -121,7 +121,7 @@ public void tryAction(ActionListener<InferenceServiceResults> listener) {
121121
} else {
122122
r.readFullResponse(l.delegateFailureAndWrap((ll, httpResult) -> {
123123
try {
124-
responseHandler.validateResponse(throttlerManager, logger, request, httpResult);
124+
responseHandler.validateResponse(throttlerManager, logger, request, httpResult, true);
125125
InferenceServiceResults inferenceResults = responseHandler.parseResult(request, httpResult);
126126
ll.onResponse(inferenceResults);
127127
} catch (Exception e) {
@@ -134,7 +134,7 @@ public void tryAction(ActionListener<InferenceServiceResults> listener) {
134134
} else {
135135
httpClient.send(request.createHttpRequest(), context, retryableListener.delegateFailure((l, r) -> {
136136
try {
137-
responseHandler.validateResponse(throttlerManager, logger, request, r);
137+
responseHandler.validateResponse(throttlerManager, logger, request, r, false);
138138
InferenceServiceResults inferenceResults = responseHandler.parseResult(request, r);
139139

140140
l.onResponse(inferenceResults);

Diff for: x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/response/AmazonBedrockResponseHandler.java

+7-2
Original file line numberDiff line numberDiff line change
@@ -22,8 +22,13 @@ public boolean canHandleStreamingResponses() {
2222
}
2323

2424
@Override
25-
public final void validateResponse(ThrottlerManager throttlerManager, Logger logger, Request request, HttpResult result)
26-
throws RetryException {
25+
public final void validateResponse(
26+
ThrottlerManager throttlerManager,
27+
Logger logger,
28+
Request request,
29+
HttpResult result,
30+
boolean checkForErrorObject
31+
) throws RetryException {
2732
// do nothing as the AWS SDK will take care of validation for us
2833
}
2934
}

Diff for: x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureopenai/response/AzureMistralOpenAiExternalResponseHandler.java

+7-2
Original file line numberDiff line numberDiff line change
@@ -63,8 +63,13 @@ public AzureMistralOpenAiExternalResponseHandler(
6363
}
6464

6565
@Override
66-
public void validateResponse(ThrottlerManager throttlerManager, Logger logger, Request request, HttpResult result)
67-
throws RetryException {
66+
public void validateResponse(
67+
ThrottlerManager throttlerManager,
68+
Logger logger,
69+
Request request,
70+
HttpResult result,
71+
boolean checkForErrorObject
72+
) throws RetryException {
6873
checkForFailureStatusCode(request, result);
6974
checkForEmptyBody(throttlerManager, logger, request, result);
7075
}

Diff for: x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/http/retry/AlwaysRetryingResponseHandler.java

+8-2
Original file line numberDiff line numberDiff line change
@@ -35,8 +35,14 @@ public AlwaysRetryingResponseHandler(
3535
this.parseFunction = Objects.requireNonNull(parseFunction);
3636
}
3737

38-
public void validateResponse(ThrottlerManager throttlerManager, Logger logger, Request request, HttpResult result)
39-
throws RetryException {
38+
@Override
39+
public void validateResponse(
40+
ThrottlerManager throttlerManager,
41+
Logger logger,
42+
Request request,
43+
HttpResult result,
44+
boolean checkForErrorObject
45+
) throws RetryException {
4046
try {
4147
checkForFailureStatusCode(throttlerManager, logger, request, result);
4248
checkForEmptyBody(throttlerManager, logger, request, result);

Diff for: x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/http/retry/BaseResponseHandlerTests.java

+32-3
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,8 @@ public void testValidateResponse_DoesNotThrowAnExceptionWhenStatus200_AndNoError
5959
mock(ThrottlerManager.class),
6060
mock(Logger.class),
6161
request,
62-
new HttpResult(response, responseJson.getBytes(StandardCharsets.UTF_8))
62+
new HttpResult(response, responseJson.getBytes(StandardCharsets.UTF_8)),
63+
true
6364
);
6465
}
6566

@@ -85,7 +86,8 @@ public void testValidateResponse_ThrowsErrorWhenMalformedErrorObjectExists() {
8586
mock(ThrottlerManager.class),
8687
mock(Logger.class),
8788
request,
88-
new HttpResult(response, responseJson.getBytes(StandardCharsets.UTF_8))
89+
new HttpResult(response, responseJson.getBytes(StandardCharsets.UTF_8)),
90+
true
8991
)
9092
);
9193

@@ -119,7 +121,8 @@ public void testValidateResponse_ThrowsErrorWhenWellFormedErrorObjectExists() {
119121
mock(ThrottlerManager.class),
120122
mock(Logger.class),
121123
request,
122-
new HttpResult(response, responseJson.getBytes(StandardCharsets.UTF_8))
124+
new HttpResult(response, responseJson.getBytes(StandardCharsets.UTF_8)),
125+
true
123126
)
124127
);
125128

@@ -130,6 +133,32 @@ public void testValidateResponse_ThrowsErrorWhenWellFormedErrorObjectExists() {
130133
);
131134
}
132135

136+
public void testValidateResponse_DoesNot_ThrowErrorWhenWellFormedErrorObjectExists_WhenCheckForErrorIsFalse() {
137+
var handler = getBaseResponseHandler();
138+
139+
String responseJson = """
140+
{
141+
"error": {
142+
"type": "not_found_error",
143+
"message": "a message"
144+
}
145+
}
146+
""";
147+
148+
var response = mock200Response();
149+
150+
var request = mock(Request.class);
151+
when(request.getInferenceEntityId()).thenReturn("abc");
152+
153+
handler.validateResponse(
154+
mock(ThrottlerManager.class),
155+
mock(Logger.class),
156+
request,
157+
new HttpResult(response, responseJson.getBytes(StandardCharsets.UTF_8)),
158+
false
159+
);
160+
}
161+
133162
private static HttpResponse mock200Response() {
134163
int statusCode = 200;
135164
var statusLine = mock(StatusLine.class);

Diff for: x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/http/retry/RetryingHttpSenderTests.java

+11-5
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@
4242
import static org.hamcrest.Matchers.is;
4343
import static org.hamcrest.Matchers.sameInstance;
4444
import static org.mockito.ArgumentMatchers.any;
45+
import static org.mockito.ArgumentMatchers.anyBoolean;
4546
import static org.mockito.Mockito.doAnswer;
4647
import static org.mockito.Mockito.doThrow;
4748
import static org.mockito.Mockito.mock;
@@ -76,7 +77,7 @@ public void testSend_CallsSenderAgain_AfterValidateResponseThrowsAnException() t
7677
Answer<InferenceServiceResults> answer = (invocation) -> inferenceResults;
7778

7879
var handler = mock(ResponseHandler.class);
79-
doThrow(new RetryException(true, "failed")).doNothing().when(handler).validateResponse(any(), any(), any(), any());
80+
doThrow(new RetryException(true, "failed")).doNothing().when(handler).validateResponse(any(), any(), any(), any(), anyBoolean());
8081
// Mockito.thenReturn() does not compile when returning a
8182
// bounded wild card list, thenAnswer must be used instead.
8283
when(handler.parseResult(any(Request.class), any(HttpResult.class))).thenAnswer(answer);
@@ -351,7 +352,7 @@ public void testSend_ReturnsFailure_WhenValidateResponseThrowsAnException_AfterO
351352
var handler = mock(ResponseHandler.class);
352353
doThrow(new RetryException(true, "failed")).doThrow(new IllegalStateException("failed again"))
353354
.when(handler)
354-
.validateResponse(any(), any(), any(), any());
355+
.validateResponse(any(), any(), any(), any(), anyBoolean());
355356
when(handler.parseResult(any(Request.class), any(HttpResult.class))).thenAnswer(answer);
356357

357358
var retrier = createRetrier(sender);
@@ -388,7 +389,7 @@ public void testSend_ReturnsFailure_WhenValidateResponseThrowsAnElasticsearchExc
388389
var handler = mock(ResponseHandler.class);
389390
doThrow(new RetryException(true, "failed")).doThrow(new RetryException(false, "failed again"))
390391
.when(handler)
391-
.validateResponse(any(), any(), any(), any());
392+
.validateResponse(any(), any(), any(), any(), anyBoolean());
392393
when(handler.parseResult(any(Request.class), any(HttpResult.class))).thenAnswer(answer);
393394

394395
var retrier = createRetrier(httpClient);
@@ -701,8 +702,13 @@ private ResponseHandler createRetryingResponseHandler() {
701702
// testing failed requests
702703
return new ResponseHandler() {
703704
@Override
704-
public void validateResponse(ThrottlerManager throttlerManager, Logger logger, Request request, HttpResult result)
705-
throws RetryException {
705+
public void validateResponse(
706+
ThrottlerManager throttlerManager,
707+
Logger logger,
708+
Request request,
709+
HttpResult result,
710+
boolean checkForErrorObject
711+
) throws RetryException {
706712
throw new RetryException(true, new IOException("response handler validate failed as designed"));
707713
}
708714

Diff for: x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/OpenAiUnifiedChatCompletionResponseHandlerTests.java

+2-1
Original file line numberDiff line numberDiff line change
@@ -96,7 +96,8 @@ private Exception invalidResponse(String responseJson) {
9696
mock(),
9797
mock(),
9898
mockRequest(),
99-
new HttpResult(mock500Response(), responseJson.getBytes(StandardCharsets.UTF_8))
99+
new HttpResult(mock500Response(), responseJson.getBytes(StandardCharsets.UTF_8)),
100+
true
100101
)
101102
);
102103
}

0 commit comments

Comments
 (0)