Commit d00f414
[ML] Make warnings from inference errors (#81735)
Consistently use exceptions for errors instead of WarningInferenceResults, to simplify debugging and triaging.
1 parent e0c88c2 commit d00f414
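The change applies one pattern across all nine files: result processors used to report recoverable problems by returning a WarningInferenceResults, which callers received as a successful response; they now throw instead. A minimal sketch of the before/after, lifted from the TextClassificationProcessor diff below (the guard and message are real; the isolated framing is illustrative):

    import org.elasticsearch.ElasticsearchStatusException;
    import org.elasticsearch.rest.RestStatus;

    // Before: the error was wrapped in a result object and the request still succeeded,
    // surfacing only as a "warning" field in the response body.
    if (pyTorchResult.getInferenceResult().length < 1) {
        return new WarningInferenceResults("Text classification result has no data");
    }

    // After: the same condition fails the request with an explicit status code,
    // surfacing as a status_exception that is easier to debug and triage.
    if (pyTorchResult.getInferenceResult().length < 1) {
        throw new ElasticsearchStatusException("Text classification result has no data", RestStatus.INTERNAL_SERVER_ERROR);
    }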

9 files changed: +51 −71 lines
x-pack/plugin/ml/qa/native-multi-node-tests/src/javaRestTest/java/org/elasticsearch/xpack/ml/integration/PyTorchModelIT.java

Lines changed: 4 additions & 2 deletions
@@ -527,8 +527,9 @@ public void testTruncation() throws IOException {
         startDeployment(modelId, AllocationStatus.State.FULLY_ALLOCATED.toString());

         String input = "once twice thrice";
+        var e = expectThrows(ResponseException.class, () -> EntityUtils.toString(infer("once twice thrice", modelId).getEntity()));
         assertThat(
-            EntityUtils.toString(infer("once twice thrice", modelId).getEntity()),
+            e.getMessage(),
             containsString("Input too large. The tokenized input length [3] exceeds the maximum sequence length [2]")
         );

@@ -637,7 +638,8 @@ public void testPipelineWithBadProcessor() throws IOException {
             """;

         response = EntityUtils.toString(client().performRequest(simulateRequest(source)).getEntity());
-        assertThat(response, containsString("warning"));
+        assertThat(response, containsString("no value could be found for input field [input]"));
+        assertThat(response, containsString("status_exception"));
     }

     public void testDeleteModelWithDeploymentUsedByIngestProcessor() throws IOException {
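The caller-visible effect of the test changes above: an inference request that previously returned 200 OK with a warning in the body now fails outright. A hedged sketch of the adjusted handling, reusing infer() and ResponseException from the test (the try/catch framing is illustrative, not part of the commit):

    try {
        // before #81735 this returned 200 and the body carried a "warning" field
        Response response = infer("once twice thrice", modelId);
        String body = EntityUtils.toString(response.getEntity());
    } catch (ResponseException e) {
        // now the detail that used to be the warning text is the error message itself
        assertThat(e.getMessage(), containsString("Input too large."));
    }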

x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/deployment/DeploymentManager.java

Lines changed: 3 additions & 20 deletions
@@ -11,7 +11,6 @@
 import org.apache.logging.log4j.Logger;
 import org.apache.logging.log4j.message.ParameterizedMessage;
 import org.apache.lucene.util.SetOnce;
-import org.elasticsearch.ElasticsearchException;
 import org.elasticsearch.ElasticsearchStatusException;
 import org.elasticsearch.ResourceNotFoundException;
 import org.elasticsearch.action.ActionListener;

@@ -34,7 +33,6 @@
 import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig;
 import org.elasticsearch.xpack.core.ml.inference.TrainedModelInput;
 import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults;
-import org.elasticsearch.xpack.core.ml.inference.results.WarningInferenceResults;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.IndexLocation;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.NlpConfig;

@@ -374,32 +372,17 @@ protected void doRun() throws Exception {
                     processContext,
                     request.tokenization,
                     processor.getResultProcessor((NlpConfig) config),
-                    ActionListener.wrap(this::onSuccess, f -> handleFailure(f, this))
+                    this
                 ),
                 this::onFailure
             )
         );
         processContext.process.get().writeInferenceRequest(request.processInput);
     } catch (IOException e) {
         logger.error(new ParameterizedMessage("[{}] error writing to process", processContext.task.getModelId()), e);
-        handleFailure(ExceptionsHelper.serverError("error writing to process", e), this);
+        onFailure(ExceptionsHelper.serverError("error writing to process", e));
     } catch (Exception e) {
-        handleFailure(e, this);
-    }
-    }
-
-    private static void handleFailure(Exception e, ActionListener<InferenceResults> listener) {
-        Throwable unwrapped = org.elasticsearch.ExceptionsHelper.unwrapCause(e);
-        if (unwrapped instanceof ElasticsearchException ex) {
-            if (ex.status() == RestStatus.BAD_REQUEST) {
-                listener.onResponse(new WarningInferenceResults(ex.getMessage()));
-            } else {
-                listener.onFailure(ex);
-            }
-        } else if (unwrapped instanceof IllegalArgumentException) {
-            listener.onResponse(new WarningInferenceResults(e.getMessage()));
-        } else {
-            listener.onFailure(e);
+        onFailure(e);
     }
 }
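With handleFailure gone, there is no branch left that downgrades BAD_REQUEST or IllegalArgumentException into a WarningInferenceResults response: the action itself is passed as the result listener, so every failure, client error or otherwise, reaches the same onFailure handler. A minimal sketch of that unified wiring (the wrap call mirrors the removed line; the standalone framing is illustrative):

    // previously: ActionListener.wrap(this::onSuccess, f -> handleFailure(f, this))
    ActionListener<InferenceResults> listener = ActionListener.wrap(
        this::onSuccess,  // parsed inference results
        this::onFailure   // any exception, including the former "warning" cases
    );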

x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/FillMaskProcessor.java

Lines changed: 5 additions & 3 deletions
@@ -7,10 +7,11 @@

 package org.elasticsearch.xpack.ml.inference.nlp;

+import org.elasticsearch.ElasticsearchStatusException;
+import org.elasticsearch.rest.RestStatus;
 import org.elasticsearch.xpack.core.ml.inference.results.FillMaskResults;
 import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults;
 import org.elasticsearch.xpack.core.ml.inference.results.TopClassEntry;
-import org.elasticsearch.xpack.core.ml.inference.results.WarningInferenceResults;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.FillMaskConfig;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.NlpConfig;
 import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;

@@ -86,7 +87,7 @@ static InferenceResults processResult(
         String resultsField
     ) {
         if (tokenization.getTokenizations().isEmpty() || tokenization.getTokenizations().get(0).getTokenIds().length == 0) {
-            return new WarningInferenceResults("No valid tokens for inference");
+            throw new ElasticsearchStatusException("tokenization is empty", RestStatus.INTERNAL_SERVER_ERROR);
         }

         if (tokenizer.getMaskTokenId().isEmpty()) {

@@ -105,8 +106,9 @@ static InferenceResults processResult(
             }
         }
         if (maskTokenIndex == -1) {
-            return new WarningInferenceResults(
+            throw new ElasticsearchStatusException(
                 "mask token id [{}] not found in the tokenization {}",
+                RestStatus.INTERNAL_SERVER_ERROR,
                 maskTokenId,
                 List.of(tokenization.getTokenizations().get(0).getTokenIds())
             );

x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/NerProcessor.java

Lines changed: 3 additions & 2 deletions
@@ -7,12 +7,13 @@

 package org.elasticsearch.xpack.ml.inference.nlp;

+import org.elasticsearch.ElasticsearchStatusException;
 import org.elasticsearch.common.ValidationException;
 import org.elasticsearch.common.io.stream.StreamOutput;
 import org.elasticsearch.common.io.stream.Writeable;
+import org.elasticsearch.rest.RestStatus;
 import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults;
 import org.elasticsearch.xpack.core.ml.inference.results.NerResults;
-import org.elasticsearch.xpack.core.ml.inference.results.WarningInferenceResults;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.NerConfig;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.NlpConfig;
 import org.elasticsearch.xpack.ml.inference.nlp.tokenizers.DelimitedToken;

@@ -195,7 +196,7 @@ static class NerResultProcessor implements NlpTask.ResultProcessor {
         @Override
         public InferenceResults processResult(TokenizationResult tokenization, PyTorchInferenceResult pyTorchResult) {
             if (tokenization.getTokenizations().isEmpty() || tokenization.getTokenizations().get(0).getTokenIds().length == 0) {
-                return new WarningInferenceResults("no valid tokens to build result");
+                throw new ElasticsearchStatusException("no valid tokenization to build result", RestStatus.INTERNAL_SERVER_ERROR);
             }
             // TODO - process all results in the batch

x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/TextClassificationProcessor.java

Lines changed: 5 additions & 8 deletions
@@ -7,10 +7,11 @@

 package org.elasticsearch.xpack.ml.inference.nlp;

+import org.elasticsearch.ElasticsearchStatusException;
+import org.elasticsearch.rest.RestStatus;
 import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults;
 import org.elasticsearch.xpack.core.ml.inference.results.NlpClassificationInferenceResults;
 import org.elasticsearch.xpack.core.ml.inference.results.TopClassEntry;
-import org.elasticsearch.xpack.core.ml.inference.results.WarningInferenceResults;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.NlpConfig;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TextClassificationConfig;
 import org.elasticsearch.xpack.ml.inference.nlp.tokenizers.NlpTokenizer;

@@ -39,11 +40,6 @@ public class TextClassificationProcessor implements NlpTask.Processor {
         // negative values are a special case of asking for ALL classes. Since we require the output size to equal the classLabel size
         // This is a nice way of setting the value
         this.numTopClasses = config.getNumTopClasses() < 0 ? this.classLabels.length : config.getNumTopClasses();
-        validate();
-    }
-
-    private void validate() {
-        // validation occurs in TextClassificationConfig
     }

     @Override

@@ -87,14 +83,15 @@ static InferenceResults processResult(
         String resultsField
     ) {
         if (pyTorchResult.getInferenceResult().length < 1) {
-            return new WarningInferenceResults("Text classification result has no data");
+            throw new ElasticsearchStatusException("Text classification result has no data", RestStatus.INTERNAL_SERVER_ERROR);
         }

         // TODO only the first entry in the batch result is verified and
         // checked. Implement for all in batch
         if (pyTorchResult.getInferenceResult()[0][0].length != labels.size()) {
-            return new WarningInferenceResults(
+            throw new ElasticsearchStatusException(
                 "Expected exactly [{}] values in text classification result; got [{}]",
+                RestStatus.INTERNAL_SERVER_ERROR,
                 labels.size(),
                 pyTorchResult.getInferenceResult()[0][0].length
             );

x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/ZeroShotClassificationProcessor.java

Lines changed: 10 additions & 6 deletions
@@ -7,11 +7,12 @@

 package org.elasticsearch.xpack.ml.inference.nlp;

+import org.elasticsearch.ElasticsearchStatusException;
 import org.elasticsearch.common.logging.LoggerMessageFormat;
+import org.elasticsearch.rest.RestStatus;
 import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults;
 import org.elasticsearch.xpack.core.ml.inference.results.NlpClassificationInferenceResults;
 import org.elasticsearch.xpack.core.ml.inference.results.TopClassEntry;
-import org.elasticsearch.xpack.core.ml.inference.results.WarningInferenceResults;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.NlpConfig;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.Tokenization;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ZeroShotClassificationConfig;

@@ -114,7 +115,7 @@ static class RequestBuilder implements NlpTask.RequestBuilder {
         @Override
         public NlpTask.Request buildRequest(List<String> inputs, String requestId, Tokenization.Truncate truncate) throws IOException {
             if (inputs.size() > 1) {
-                throw new IllegalArgumentException("Unable to do zero-shot classification on more than one text input at a time");
+                throw ExceptionsHelper.badRequestException("Unable to do zero-shot classification on more than one text input at a time");
             }
             List<TokenizationResult.Tokenization> tokenizations = new ArrayList<>(labels.length);
             for (String label : labels) {

@@ -148,13 +149,14 @@ static class ResultProcessor implements NlpTask.ResultProcessor {
         @Override
         public InferenceResults processResult(TokenizationResult tokenization, PyTorchInferenceResult pyTorchResult) {
             if (pyTorchResult.getInferenceResult().length < 1) {
-                return new WarningInferenceResults("Zero shot classification result has no data");
+                throw new ElasticsearchStatusException("Zero shot classification result has no data", RestStatus.INTERNAL_SERVER_ERROR);
             }
             // TODO only the first entry in the batch result is verified and
             // checked. Implement for all in batch
             if (pyTorchResult.getInferenceResult()[0].length != labels.length) {
-                return new WarningInferenceResults(
+                throw new ElasticsearchStatusException(
                     "Expected exactly [{}] values in zero shot classification result; got [{}]",
+                    RestStatus.INTERNAL_SERVER_ERROR,
                     labels.length,
                     pyTorchResult.getInferenceResult().length
                 );

@@ -165,8 +167,9 @@ public InferenceResults processResult(TokenizationResult tokenization, PyTorchIn
             int v = 0;
             for (double[] vals : pyTorchResult.getInferenceResult()[0]) {
                 if (vals.length != 3) {
-                    return new WarningInferenceResults(
+                    throw new ElasticsearchStatusException(
                         "Expected exactly [{}] values in inner zero shot classification result; got [{}]",
+                        RestStatus.INTERNAL_SERVER_ERROR,
                         3,
                         vals.length
                     );

@@ -181,8 +184,9 @@ public InferenceResults processResult(TokenizationResult tokenization, PyTorchIn
             int v = 0;
             for (double[] vals : pyTorchResult.getInferenceResult()[0]) {
                 if (vals.length != 3) {
-                    return new WarningInferenceResults(
+                    throw new ElasticsearchStatusException(
                         "Expected exactly [{}] values in inner zero shot classification result; got [{}]",
+                        RestStatus.INTERNAL_SERVER_ERROR,
                         3,
                         vals.length
                     );

x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/nlp/FillMaskProcessorTests.java

Lines changed: 3 additions & 5 deletions
@@ -11,7 +11,6 @@
 import org.elasticsearch.test.ESTestCase;
 import org.elasticsearch.xpack.core.ml.inference.results.FillMaskResults;
 import org.elasticsearch.xpack.core.ml.inference.results.TopClassEntry;
-import org.elasticsearch.xpack.core.ml.inference.results.WarningInferenceResults;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.FillMaskConfig;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.VocabularyConfig;
 import org.elasticsearch.xpack.ml.inference.nlp.tokenizers.BasicTokenizer;

@@ -28,7 +27,6 @@
 import static org.hamcrest.Matchers.containsString;
 import static org.hamcrest.Matchers.equalTo;
 import static org.hamcrest.Matchers.hasSize;
-import static org.hamcrest.Matchers.instanceOf;
 import static org.mockito.Mockito.mock;
 import static org.mockito.Mockito.when;

@@ -90,9 +88,9 @@ public void testProcessResults_GivenMissingTokens() {
         tokenization.addTokenization("", false, Collections.emptyList(), new int[] {}, new int[] {});

         PyTorchInferenceResult pyTorchResult = new PyTorchInferenceResult("1", new double[][][] { { {} } }, 0L, null);
-        assertThat(
-            FillMaskProcessor.processResult(tokenization, pyTorchResult, tokenizer, 5, randomAlphaOfLength(10)),
-            instanceOf(WarningInferenceResults.class)
+        expectThrows(
+            ElasticsearchStatusException.class,
+            () -> FillMaskProcessor.processResult(tokenization, pyTorchResult, tokenizer, 5, randomAlphaOfLength(10))
         );
     }

x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/nlp/NerProcessorTests.java

Lines changed: 6 additions & 4 deletions
@@ -7,10 +7,10 @@

 package org.elasticsearch.xpack.ml.inference.nlp;

+import org.elasticsearch.ElasticsearchStatusException;
 import org.elasticsearch.common.ValidationException;
 import org.elasticsearch.test.ESTestCase;
 import org.elasticsearch.xpack.core.ml.inference.results.NerResults;
-import org.elasticsearch.xpack.core.ml.inference.results.WarningInferenceResults;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.BertTokenization;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.NerConfig;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.Tokenization;

@@ -91,10 +91,12 @@ public void testValidate_NotAEntityLabel() {
     public void testProcessResults_GivenNoTokens() {
         NerProcessor.NerResultProcessor processor = new NerProcessor.NerResultProcessor(NerProcessor.IobTag.values(), null, false);
         TokenizationResult tokenization = tokenize(List.of(BertTokenizer.PAD_TOKEN, BertTokenizer.UNKNOWN_TOKEN), "");
-        assertThat(
-            processor.processResult(tokenization, new PyTorchInferenceResult("test", null, 0L, null)),
-            instanceOf(WarningInferenceResults.class)
+
+        var e = expectThrows(
+            ElasticsearchStatusException.class,
+            () -> processor.processResult(tokenization, new PyTorchInferenceResult("test", null, 0L, null))
         );
+        assertThat(e, instanceOf(ElasticsearchStatusException.class));
     }

     public void testProcessResults() {

x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/nlp/TextClassificationProcessorTests.java

Lines changed: 12 additions & 21 deletions
@@ -7,11 +7,10 @@

 package org.elasticsearch.xpack.ml.inference.nlp;

+import org.elasticsearch.ElasticsearchStatusException;
 import org.elasticsearch.common.xcontent.XContentHelper;
 import org.elasticsearch.test.ESTestCase;
 import org.elasticsearch.xcontent.XContentType;
-import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults;
-import org.elasticsearch.xpack.core.ml.inference.results.WarningInferenceResults;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.BertTokenization;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TextClassificationConfig;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.Tokenization;

@@ -25,6 +24,7 @@
 import java.util.Map;

 import static org.elasticsearch.xpack.ml.inference.nlp.tokenizers.BertTokenizerTests.TEST_CASED_VOCAB;
+import static org.hamcrest.Matchers.containsString;
 import static org.hamcrest.Matchers.hasSize;
 import static org.hamcrest.Matchers.instanceOf;

@@ -33,30 +33,21 @@ public class TextClassificationProcessorTests extends ESTestCase {
     public void testInvalidResult() {
         {
             PyTorchInferenceResult torchResult = new PyTorchInferenceResult("foo", new double[][][] {}, 0L, null);
-            InferenceResults inferenceResults = TextClassificationProcessor.processResult(
-                null,
-                torchResult,
-                randomInt(),
-                List.of("a", "b"),
-                randomAlphaOfLength(10)
+            var e = expectThrows(
+                ElasticsearchStatusException.class,
+                () -> TextClassificationProcessor.processResult(null, torchResult, randomInt(), List.of("a", "b"), randomAlphaOfLength(10))
             );
-            assertThat(inferenceResults, instanceOf(WarningInferenceResults.class));
-            assertEquals("Text classification result has no data", ((WarningInferenceResults) inferenceResults).getWarning());
+            assertThat(e, instanceOf(ElasticsearchStatusException.class));
+            assertThat(e.getMessage(), containsString("Text classification result has no data"));
         }
         {
             PyTorchInferenceResult torchResult = new PyTorchInferenceResult("foo", new double[][][] { { { 1.0 } } }, 0L, null);
-            InferenceResults inferenceResults = TextClassificationProcessor.processResult(
-                null,
-                torchResult,
-                randomInt(),
-                List.of("a", "b"),
-                randomAlphaOfLength(10)
-            );
-            assertThat(inferenceResults, instanceOf(WarningInferenceResults.class));
-            assertEquals(
-                "Expected exactly [2] values in text classification result; got [1]",
-                ((WarningInferenceResults) inferenceResults).getWarning()
+            var e = expectThrows(
+                ElasticsearchStatusException.class,
+                () -> TextClassificationProcessor.processResult(null, torchResult, randomInt(), List.of("a", "b"), randomAlphaOfLength(10))
             );
+            assertThat(e, instanceOf(ElasticsearchStatusException.class));
+            assertThat(e.getMessage(), containsString("Expected exactly [2] values in text classification result; got [1]"));
         }
     }
