Skip to content

[ML] Refactoring inference API non-streaming response validation error object check #126725

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -76,13 +76,21 @@ public String getRequestType() {
}

@Override
public void validateResponse(ThrottlerManager throttlerManager, Logger logger, Request request, HttpResult result) {
public void validateResponse(
ThrottlerManager throttlerManager,
Logger logger,
Request request,
HttpResult result,
boolean checkForErrorObject
) {
checkForFailureStatusCode(request, result);
checkForEmptyBody(throttlerManager, logger, request, result);

// When the response is streamed the status code could be 200 but the error object will be set
// so we need to check for that specifically
checkForErrorObject(request, result);
if (checkForErrorObject) {
// When the response is streamed the status code could be 200 but the error object will be set
// so we need to check for that specifically
checkForErrorObject(request, result);
}
}

protected abstract void checkForFailureStatusCode(Request request, HttpResult result);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,9 +29,12 @@ public interface ResponseHandler {
* @param logger the logger to use for logging
* @param request the original request
* @param result the response from the server
* @param checkForErrorObject if true, the validation function should check for the presence of an error object even if the status code
* indicates a success
* @throws RetryException if the response is invalid
*/
void validateResponse(ThrottlerManager throttlerManager, Logger logger, Request request, HttpResult result) throws RetryException;
void validateResponse(ThrottlerManager throttlerManager, Logger logger, Request request, HttpResult result, boolean checkForErrorObject)
throws RetryException;

/**
* A method for parsing the response from the server.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@ public void tryAction(ActionListener<InferenceServiceResults> listener) {
} else {
r.readFullResponse(l.delegateFailureAndWrap((ll, httpResult) -> {
try {
responseHandler.validateResponse(throttlerManager, logger, request, httpResult);
responseHandler.validateResponse(throttlerManager, logger, request, httpResult, true);
InferenceServiceResults inferenceResults = responseHandler.parseResult(request, httpResult);
ll.onResponse(inferenceResults);
} catch (Exception e) {
Expand All @@ -134,7 +134,7 @@ public void tryAction(ActionListener<InferenceServiceResults> listener) {
} else {
httpClient.send(request.createHttpRequest(), context, retryableListener.delegateFailure((l, r) -> {
try {
responseHandler.validateResponse(throttlerManager, logger, request, r);
responseHandler.validateResponse(throttlerManager, logger, request, r, false);
InferenceServiceResults inferenceResults = responseHandler.parseResult(request, r);

l.onResponse(inferenceResults);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,13 @@ public boolean canHandleStreamingResponses() {
}

@Override
public final void validateResponse(ThrottlerManager throttlerManager, Logger logger, Request request, HttpResult result)
throws RetryException {
// Intentionally a no-op: none of the arguments, including checkForErrorObject, are inspected here.
public final void validateResponse(
ThrottlerManager throttlerManager,
Logger logger,
Request request,
HttpResult result,
boolean checkForErrorObject
) throws RetryException {
// do nothing as the AWS SDK will take care of validation for us
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -63,8 +63,13 @@ public AzureMistralOpenAiExternalResponseHandler(
}

@Override
public void validateResponse(ThrottlerManager throttlerManager, Logger logger, Request request, HttpResult result)
throws RetryException {
// Validates the response by running the failure-status-code check followed by the empty-body check.
// The checkForErrorObject flag is accepted to satisfy the interface signature but is not consulted
// in this implementation.
public void validateResponse(
ThrottlerManager throttlerManager,
Logger logger,
Request request,
HttpResult result,
boolean checkForErrorObject
) throws RetryException {
checkForFailureStatusCode(request, result);
checkForEmptyBody(throttlerManager, logger, request, result);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,14 @@ public AlwaysRetryingResponseHandler(
this.parseFunction = Objects.requireNonNull(parseFunction);
}

public void validateResponse(ThrottlerManager throttlerManager, Logger logger, Request request, HttpResult result)
throws RetryException {
@Override
public void validateResponse(
ThrottlerManager throttlerManager,
Logger logger,
Request request,
HttpResult result,
boolean checkForErrorObject
) throws RetryException {
try {
checkForFailureStatusCode(throttlerManager, logger, request, result);
checkForEmptyBody(throttlerManager, logger, request, result);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,8 @@ public void testValidateResponse_DoesNotThrowAnExceptionWhenStatus200_AndNoError
mock(ThrottlerManager.class),
mock(Logger.class),
request,
new HttpResult(response, responseJson.getBytes(StandardCharsets.UTF_8))
new HttpResult(response, responseJson.getBytes(StandardCharsets.UTF_8)),
true
);
}

Expand All @@ -85,7 +86,8 @@ public void testValidateResponse_ThrowsErrorWhenMalformedErrorObjectExists() {
mock(ThrottlerManager.class),
mock(Logger.class),
request,
new HttpResult(response, responseJson.getBytes(StandardCharsets.UTF_8))
new HttpResult(response, responseJson.getBytes(StandardCharsets.UTF_8)),
true
)
);

Expand Down Expand Up @@ -119,7 +121,8 @@ public void testValidateResponse_ThrowsErrorWhenWellFormedErrorObjectExists() {
mock(ThrottlerManager.class),
mock(Logger.class),
request,
new HttpResult(response, responseJson.getBytes(StandardCharsets.UTF_8))
new HttpResult(response, responseJson.getBytes(StandardCharsets.UTF_8)),
true
)
);

Expand All @@ -130,6 +133,32 @@ public void testValidateResponse_ThrowsErrorWhenWellFormedErrorObjectExists() {
);
}

// Covers the no-throw path named by the test: a status-200 response whose body holds a
// well-formed "error" object is validated with checkForErrorObject set to false.
// There are no explicit assertions; the test fails only if validateResponse throws.
public void testValidateResponse_DoesNot_ThrowErrorWhenWellFormedErrorObjectExists_WhenCheckForErrorIsFalse() {
var handler = getBaseResponseHandler();

// Response body containing an "error" object with both "type" and "message" fields.
String responseJson = """
{
"error": {
"type": "not_found_error",
"message": "a message"
}
}
""";

var response = mock200Response();

var request = mock(Request.class);
when(request.getInferenceEntityId()).thenReturn("abc");

handler.validateResponse(
mock(ThrottlerManager.class),
mock(Logger.class),
request,
new HttpResult(response, responseJson.getBytes(StandardCharsets.UTF_8)),
// checkForErrorObject argument: false, so the error-object check is expected to be skipped
false
);
}

private static HttpResponse mock200Response() {
int statusCode = 200;
var statusLine = mock(StatusLine.class);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@
import static org.hamcrest.Matchers.is;
import static org.hamcrest.Matchers.sameInstance;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.anyBoolean;
import static org.mockito.Mockito.doAnswer;
import static org.mockito.Mockito.doThrow;
import static org.mockito.Mockito.mock;
Expand Down Expand Up @@ -76,7 +77,7 @@ public void testSend_CallsSenderAgain_AfterValidateResponseThrowsAnException() t
Answer<InferenceServiceResults> answer = (invocation) -> inferenceResults;

var handler = mock(ResponseHandler.class);
doThrow(new RetryException(true, "failed")).doNothing().when(handler).validateResponse(any(), any(), any(), any());
doThrow(new RetryException(true, "failed")).doNothing().when(handler).validateResponse(any(), any(), any(), any(), anyBoolean());
// Mockito.thenReturn() does not compile when returning a
// bounded wildcard list, so thenAnswer must be used instead.
when(handler.parseResult(any(Request.class), any(HttpResult.class))).thenAnswer(answer);
Expand Down Expand Up @@ -351,7 +352,7 @@ public void testSend_ReturnsFailure_WhenValidateResponseThrowsAnException_AfterO
var handler = mock(ResponseHandler.class);
doThrow(new RetryException(true, "failed")).doThrow(new IllegalStateException("failed again"))
.when(handler)
.validateResponse(any(), any(), any(), any());
.validateResponse(any(), any(), any(), any(), anyBoolean());
when(handler.parseResult(any(Request.class), any(HttpResult.class))).thenAnswer(answer);

var retrier = createRetrier(sender);
Expand Down Expand Up @@ -388,7 +389,7 @@ public void testSend_ReturnsFailure_WhenValidateResponseThrowsAnElasticsearchExc
var handler = mock(ResponseHandler.class);
doThrow(new RetryException(true, "failed")).doThrow(new RetryException(false, "failed again"))
.when(handler)
.validateResponse(any(), any(), any(), any());
.validateResponse(any(), any(), any(), any(), anyBoolean());
when(handler.parseResult(any(Request.class), any(HttpResult.class))).thenAnswer(answer);

var retrier = createRetrier(httpClient);
Expand Down Expand Up @@ -701,8 +702,13 @@ private ResponseHandler createRetryingResponseHandler() {
// testing failed requests
return new ResponseHandler() {
@Override
public void validateResponse(ThrottlerManager throttlerManager, Logger logger, Request request, HttpResult result)
throws RetryException {
// Test stub: unconditionally throws a RetryException wrapping an IOException, ignoring every
// argument (including checkForErrorObject).
// NOTE(review): the leading `true` presumably marks the failure as retryable — confirm against
// the RetryException constructor.
public void validateResponse(
ThrottlerManager throttlerManager,
Logger logger,
Request request,
HttpResult result,
boolean checkForErrorObject
) throws RetryException {
throw new RetryException(true, new IOException("response handler validate failed as designed"));
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,8 @@ private Exception invalidResponse(String responseJson) {
mock(),
mock(),
mockRequest(),
new HttpResult(mock500Response(), responseJson.getBytes(StandardCharsets.UTF_8))
new HttpResult(mock500Response(), responseJson.getBytes(StandardCharsets.UTF_8)),
true
)
);
}
Expand Down