Skip to content

Commit 1017fe3

Browse files
committed
[ML] Make warnings from inference errors (elastic#81735)
Consistently using exceptions for errors instead of WarningInferenceResults, to simplify debugging/triaging.

Conflicts:
- x-pack/plugin/ml/qa/native-multi-node-tests/src/javaRestTest/java/org/elasticsearch/xpack/ml/integration/PyTorchModelIT.java
- x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/FillMaskProcessor.java
- x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/NerProcessor.java
- x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/nlp/FillMaskProcessorTests.java
- x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/nlp/NerProcessorTests.java
- x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/nlp/TextClassificationProcessorTests.java
1 parent 01e9e8c commit 1017fe3

File tree

9 files changed

+45
-70
lines changed

9 files changed

+45
-70
lines changed

x-pack/plugin/ml/qa/native-multi-node-tests/src/javaRestTest/java/org/elasticsearch/xpack/ml/integration/PyTorchModelIT.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -605,7 +605,8 @@ public void testPipelineWithBadProcessor() throws IOException {
605605
""";
606606

607607
response = EntityUtils.toString(client().performRequest(simulateRequest(source)).getEntity());
608-
assertThat(response, containsString("warning"));
608+
assertThat(response, containsString("no value could be found for input field [input]"));
609+
assertThat(response, containsString("status_exception"));
609610
}
610611

611612
private int sumInferenceCountOnNodes(List<Map<String, Object>> nodes) {

x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/deployment/DeploymentManager.java

Lines changed: 3 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,6 @@
1111
import org.apache.logging.log4j.Logger;
1212
import org.apache.logging.log4j.message.ParameterizedMessage;
1313
import org.apache.lucene.util.SetOnce;
14-
import org.elasticsearch.ElasticsearchException;
1514
import org.elasticsearch.ElasticsearchStatusException;
1615
import org.elasticsearch.ResourceNotFoundException;
1716
import org.elasticsearch.action.ActionListener;
@@ -34,7 +33,6 @@
3433
import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig;
3534
import org.elasticsearch.xpack.core.ml.inference.TrainedModelInput;
3635
import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults;
37-
import org.elasticsearch.xpack.core.ml.inference.results.WarningInferenceResults;
3836
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.IndexLocation;
3937
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig;
4038
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.NlpConfig;
@@ -369,32 +367,17 @@ protected void doRun() throws Exception {
369367
processContext,
370368
request.tokenization,
371369
processor.getResultProcessor((NlpConfig) config),
372-
ActionListener.wrap(this::onSuccess, f -> handleFailure(f, this))
370+
this
373371
),
374372
this::onFailure
375373
)
376374
);
377375
processContext.process.get().writeInferenceRequest(request.processInput);
378376
} catch (IOException e) {
379377
logger.error(new ParameterizedMessage("[{}] error writing to process", processContext.task.getModelId()), e);
380-
handleFailure(ExceptionsHelper.serverError("error writing to process", e), this);
378+
onFailure(ExceptionsHelper.serverError("error writing to process", e));
381379
} catch (Exception e) {
382-
handleFailure(e, this);
383-
}
384-
}
385-
386-
private static void handleFailure(Exception e, ActionListener<InferenceResults> listener) {
387-
Throwable unwrapped = org.elasticsearch.ExceptionsHelper.unwrapCause(e);
388-
if (unwrapped instanceof ElasticsearchException ex) {
389-
if (ex.status() == RestStatus.BAD_REQUEST) {
390-
listener.onResponse(new WarningInferenceResults(ex.getMessage()));
391-
} else {
392-
listener.onFailure(ex);
393-
}
394-
} else if (unwrapped instanceof IllegalArgumentException) {
395-
listener.onResponse(new WarningInferenceResults(e.getMessage()));
396-
} else {
397-
listener.onFailure(e);
380+
onFailure(e);
398381
}
399382
}
400383

x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/FillMaskProcessor.java

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,10 +7,11 @@
77

88
package org.elasticsearch.xpack.ml.inference.nlp;
99

10+
import org.elasticsearch.ElasticsearchStatusException;
11+
import org.elasticsearch.rest.RestStatus;
1012
import org.elasticsearch.xpack.core.ml.inference.results.FillMaskResults;
1113
import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults;
1214
import org.elasticsearch.xpack.core.ml.inference.results.TopClassEntry;
13-
import org.elasticsearch.xpack.core.ml.inference.results.WarningInferenceResults;
1415
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.FillMaskConfig;
1516
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.NlpConfig;
1617
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
@@ -79,10 +80,11 @@ static InferenceResults processResult(
7980
String resultsField
8081
) {
8182
if (tokenization.getTokenizations().isEmpty() || tokenization.getTokenizations().get(0).getTokens().length == 0) {
82-
return new WarningInferenceResults("No valid tokens for inference");
83+
throw new ElasticsearchStatusException("tokenization is empty", RestStatus.INTERNAL_SERVER_ERROR);
8384
}
8485

8586
int maskTokenIndex = Arrays.asList(tokenization.getTokenizations().get(0).getTokens()).indexOf(BertTokenizer.MASK_TOKEN);
87+
8688
// TODO - process all results in the batch
8789
double[] normalizedScores = NlpHelpers.convertToProbabilitiesBySoftMax(pyTorchResult.getInferenceResult()[0][maskTokenIndex]);
8890

x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/NerProcessor.java

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,12 +7,13 @@
77

88
package org.elasticsearch.xpack.ml.inference.nlp;
99

10+
import org.elasticsearch.ElasticsearchStatusException;
1011
import org.elasticsearch.common.ValidationException;
1112
import org.elasticsearch.common.io.stream.StreamOutput;
1213
import org.elasticsearch.common.io.stream.Writeable;
14+
import org.elasticsearch.rest.RestStatus;
1315
import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults;
1416
import org.elasticsearch.xpack.core.ml.inference.results.NerResults;
15-
import org.elasticsearch.xpack.core.ml.inference.results.WarningInferenceResults;
1617
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.NerConfig;
1718
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.NlpConfig;
1819
import org.elasticsearch.xpack.ml.inference.deployment.PyTorchResult;
@@ -194,7 +195,7 @@ static class NerResultProcessor implements NlpTask.ResultProcessor {
194195
@Override
195196
public InferenceResults processResult(TokenizationResult tokenization, PyTorchResult pyTorchResult) {
196197
if (tokenization.getTokenizations().isEmpty() || tokenization.getTokenizations().get(0).getTokens().length == 0) {
197-
return new WarningInferenceResults("no valid tokens to build result");
198+
throw new ElasticsearchStatusException("no valid tokenization to build result", RestStatus.INTERNAL_SERVER_ERROR);
198199
}
199200
// TODO - process all results in the batch
200201

x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/TextClassificationProcessor.java

Lines changed: 5 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -7,10 +7,11 @@
77

88
package org.elasticsearch.xpack.ml.inference.nlp;
99

10+
import org.elasticsearch.ElasticsearchStatusException;
11+
import org.elasticsearch.rest.RestStatus;
1012
import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults;
1113
import org.elasticsearch.xpack.core.ml.inference.results.NlpClassificationInferenceResults;
1214
import org.elasticsearch.xpack.core.ml.inference.results.TopClassEntry;
13-
import org.elasticsearch.xpack.core.ml.inference.results.WarningInferenceResults;
1415
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.NlpConfig;
1516
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TextClassificationConfig;
1617
import org.elasticsearch.xpack.ml.inference.deployment.PyTorchResult;
@@ -39,11 +40,6 @@ public class TextClassificationProcessor implements NlpTask.Processor {
3940
// negative values are a special case of asking for ALL classes. Since we require the output size to equal the classLabel size
4041
// This is a nice way of setting the value
4142
this.numTopClasses = config.getNumTopClasses() < 0 ? this.classLabels.length : config.getNumTopClasses();
42-
validate();
43-
}
44-
45-
private void validate() {
46-
// validation occurs in TextClassificationConfig
4743
}
4844

4945
@Override
@@ -87,14 +83,15 @@ static InferenceResults processResult(
8783
String resultsField
8884
) {
8985
if (pyTorchResult.getInferenceResult().length < 1) {
90-
return new WarningInferenceResults("Text classification result has no data");
86+
throw new ElasticsearchStatusException("Text classification result has no data", RestStatus.INTERNAL_SERVER_ERROR);
9187
}
9288

9389
// TODO only the first entry in the batch result is verified and
9490
// checked. Implement for all in batch
9591
if (pyTorchResult.getInferenceResult()[0][0].length != labels.size()) {
96-
return new WarningInferenceResults(
92+
throw new ElasticsearchStatusException(
9793
"Expected exactly [{}] values in text classification result; got [{}]",
94+
RestStatus.INTERNAL_SERVER_ERROR,
9895
labels.size(),
9996
pyTorchResult.getInferenceResult()[0][0].length
10097
);

x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/nlp/ZeroShotClassificationProcessor.java

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -7,11 +7,12 @@
77

88
package org.elasticsearch.xpack.ml.inference.nlp;
99

10+
import org.elasticsearch.ElasticsearchStatusException;
1011
import org.elasticsearch.common.logging.LoggerMessageFormat;
12+
import org.elasticsearch.rest.RestStatus;
1113
import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults;
1214
import org.elasticsearch.xpack.core.ml.inference.results.NlpClassificationInferenceResults;
1315
import org.elasticsearch.xpack.core.ml.inference.results.TopClassEntry;
14-
import org.elasticsearch.xpack.core.ml.inference.results.WarningInferenceResults;
1516
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.NlpConfig;
1617
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ZeroShotClassificationConfig;
1718
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
@@ -113,7 +114,7 @@ static class RequestBuilder implements NlpTask.RequestBuilder {
113114
@Override
114115
public NlpTask.Request buildRequest(List<String> inputs, String requestId) throws IOException {
115116
if (inputs.size() > 1) {
116-
throw new IllegalArgumentException("Unable to do zero-shot classification on more than one text input at a time");
117+
throw ExceptionsHelper.badRequestException("Unable to do zero-shot classification on more than one text input at a time");
117118
}
118119
List<TokenizationResult.Tokenization> tokenizations = new ArrayList<>(labels.length);
119120
for (String label : labels) {
@@ -147,13 +148,14 @@ static class ResultProcessor implements NlpTask.ResultProcessor {
147148
@Override
148149
public InferenceResults processResult(TokenizationResult tokenization, PyTorchResult pyTorchResult) {
149150
if (pyTorchResult.getInferenceResult().length < 1) {
150-
return new WarningInferenceResults("Zero shot classification result has no data");
151+
throw new ElasticsearchStatusException("Zero shot classification result has no data", RestStatus.INTERNAL_SERVER_ERROR);
151152
}
152153
// TODO only the first entry in the batch result is verified and
153154
// checked. Implement for all in batch
154155
if (pyTorchResult.getInferenceResult()[0].length != labels.length) {
155-
return new WarningInferenceResults(
156+
throw new ElasticsearchStatusException(
156157
"Expected exactly [{}] values in zero shot classification result; got [{}]",
158+
RestStatus.INTERNAL_SERVER_ERROR,
157159
labels.length,
158160
pyTorchResult.getInferenceResult().length
159161
);
@@ -164,8 +166,9 @@ public InferenceResults processResult(TokenizationResult tokenization, PyTorchRe
164166
int v = 0;
165167
for (double[] vals : pyTorchResult.getInferenceResult()[0]) {
166168
if (vals.length != 3) {
167-
return new WarningInferenceResults(
169+
throw new ElasticsearchStatusException(
168170
"Expected exactly [{}] values in inner zero shot classification result; got [{}]",
171+
RestStatus.INTERNAL_SERVER_ERROR,
169172
3,
170173
vals.length
171174
);
@@ -180,8 +183,9 @@ public InferenceResults processResult(TokenizationResult tokenization, PyTorchRe
180183
int v = 0;
181184
for (double[] vals : pyTorchResult.getInferenceResult()[0]) {
182185
if (vals.length != 3) {
183-
return new WarningInferenceResults(
186+
throw new ElasticsearchStatusException(
184187
"Expected exactly [{}] values in inner zero shot classification result; got [{}]",
188+
RestStatus.INTERNAL_SERVER_ERROR,
185189
3,
186190
vals.length
187191
);

x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/nlp/FillMaskProcessorTests.java

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,6 @@
1111
import org.elasticsearch.test.ESTestCase;
1212
import org.elasticsearch.xpack.core.ml.inference.results.FillMaskResults;
1313
import org.elasticsearch.xpack.core.ml.inference.results.TopClassEntry;
14-
import org.elasticsearch.xpack.core.ml.inference.results.WarningInferenceResults;
1514
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.FillMaskConfig;
1615
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.VocabularyConfig;
1716
import org.elasticsearch.xpack.ml.inference.deployment.PyTorchResult;
@@ -25,7 +24,6 @@
2524
import static org.hamcrest.Matchers.containsString;
2625
import static org.hamcrest.Matchers.equalTo;
2726
import static org.hamcrest.Matchers.hasSize;
28-
import static org.hamcrest.Matchers.instanceOf;
2927
import static org.mockito.Mockito.mock;
3028

3129
public class FillMaskProcessorTests extends ESTestCase {
@@ -77,9 +75,9 @@ public void testProcessResults_GivenMissingTokens() {
7775
tokenization.addTokenization("", false, new String[] {}, new int[] {}, new int[] {});
7876

7977
PyTorchResult pyTorchResult = new PyTorchResult("1", new double[][][] { { {} } }, 0L, null);
80-
assertThat(
81-
FillMaskProcessor.processResult(tokenization, pyTorchResult, 5, randomAlphaOfLength(10)),
82-
instanceOf(WarningInferenceResults.class)
78+
expectThrows(
79+
ElasticsearchStatusException.class,
80+
() -> FillMaskProcessor.processResult(tokenization, pyTorchResult, 5, randomAlphaOfLength(10))
8381
);
8482
}
8583

x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/nlp/NerProcessorTests.java

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,10 +7,10 @@
77

88
package org.elasticsearch.xpack.ml.inference.nlp;
99

10+
import org.elasticsearch.ElasticsearchStatusException;
1011
import org.elasticsearch.common.ValidationException;
1112
import org.elasticsearch.test.ESTestCase;
1213
import org.elasticsearch.xpack.core.ml.inference.results.NerResults;
13-
import org.elasticsearch.xpack.core.ml.inference.results.WarningInferenceResults;
1414
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.BertTokenization;
1515
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.NerConfig;
1616
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.Tokenization;
@@ -89,10 +89,11 @@ public void testValidate_NotAEntityLabel() {
8989
public void testProcessResults_GivenNoTokens() {
9090
NerProcessor.NerResultProcessor processor = new NerProcessor.NerResultProcessor(NerProcessor.IobTag.values(), null, false);
9191
TokenizationResult tokenization = tokenize(Collections.emptyList(), "");
92-
assertThat(
93-
processor.processResult(tokenization, new PyTorchResult("test", null, 0L, null)),
94-
instanceOf(WarningInferenceResults.class)
92+
var e = expectThrows(
93+
ElasticsearchStatusException.class,
94+
() -> processor.processResult(tokenization, new PyTorchResult("test", null, 0L, null))
9595
);
96+
assertThat(e, instanceOf(ElasticsearchStatusException.class));
9697
}
9798

9899
public void testProcessResults() {

x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/nlp/TextClassificationProcessorTests.java

Lines changed: 10 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -7,11 +7,10 @@
77

88
package org.elasticsearch.xpack.ml.inference.nlp;
99

10+
import org.elasticsearch.ElasticsearchStatusException;
1011
import org.elasticsearch.common.xcontent.XContentHelper;
1112
import org.elasticsearch.test.ESTestCase;
1213
import org.elasticsearch.xcontent.XContentType;
13-
import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults;
14-
import org.elasticsearch.xpack.core.ml.inference.results.WarningInferenceResults;
1514
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.BertTokenization;
1615
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TextClassificationConfig;
1716
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.Tokenization;
@@ -25,38 +24,27 @@
2524
import java.util.List;
2625
import java.util.Map;
2726

27+
import static org.hamcrest.Matchers.containsString;
2828
import static org.hamcrest.Matchers.hasSize;
29-
import static org.hamcrest.Matchers.instanceOf;
3029

3130
public class TextClassificationProcessorTests extends ESTestCase {
3231

3332
public void testInvalidResult() {
3433
{
3534
PyTorchResult torchResult = new PyTorchResult("foo", new double[][][] {}, 0L, null);
36-
InferenceResults inferenceResults = TextClassificationProcessor.processResult(
37-
null,
38-
torchResult,
39-
randomInt(),
40-
List.of("a", "b"),
41-
randomAlphaOfLength(10)
35+
var e = expectThrows(
36+
ElasticsearchStatusException.class,
37+
() -> TextClassificationProcessor.processResult(null, torchResult, randomInt(), List.of("a", "b"), randomAlphaOfLength(10))
4238
);
43-
assertThat(inferenceResults, instanceOf(WarningInferenceResults.class));
44-
assertEquals("Text classification result has no data", ((WarningInferenceResults) inferenceResults).getWarning());
39+
assertThat(e.getMessage(), containsString("Text classification result has no data"));
4540
}
4641
{
4742
PyTorchResult torchResult = new PyTorchResult("foo", new double[][][] { { { 1.0 } } }, 0L, null);
48-
InferenceResults inferenceResults = TextClassificationProcessor.processResult(
49-
null,
50-
torchResult,
51-
randomInt(),
52-
List.of("a", "b"),
53-
randomAlphaOfLength(10)
54-
);
55-
assertThat(inferenceResults, instanceOf(WarningInferenceResults.class));
56-
assertEquals(
57-
"Expected exactly [2] values in text classification result; got [1]",
58-
((WarningInferenceResults) inferenceResults).getWarning()
43+
var e = expectThrows(
44+
ElasticsearchStatusException.class,
45+
() -> TextClassificationProcessor.processResult(null, torchResult, randomInt(), List.of("a", "b"), randomAlphaOfLength(10))
5946
);
47+
assertThat(e.getMessage(), containsString("Expected exactly [2] values in text classification result; got [1]"));
6048
}
6149
}
6250

0 commit comments

Comments (0)