Skip to content

Better error message when an inference model cannot be parsed due to its size #59166

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit on Jul 8, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

package org.elasticsearch.xpack.core.ml.inference;

import org.elasticsearch.ExceptionsHelper;
import org.elasticsearch.common.CheckedFunction;
import org.elasticsearch.common.bytes.BytesArray;
import org.elasticsearch.common.bytes.BytesReference;
Expand All @@ -16,6 +17,7 @@
import org.elasticsearch.common.xcontent.NamedXContentRegistry;
import org.elasticsearch.common.xcontent.ToXContentObject;
import org.elasticsearch.common.xcontent.XContentHelper;
import org.elasticsearch.common.xcontent.XContentParseException;
import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.common.xcontent.XContentType;
import org.elasticsearch.common.xcontent.json.JsonXContent;
Expand Down Expand Up @@ -51,10 +53,29 @@ public static <T extends ToXContentObject> String deflate(T objectToCompress) th
/**
 * Inflates (decompresses) the given compressed string and parses the result
 * with the supplied parser function, bounded by the default maximum
 * inflated size {@code MAX_INFLATED_BYTES}.
 *
 * @param compressedString the deflated, base64-style compressed payload
 * @param parserFunction   function applied to the parser over the inflated content
 * @param xContentRegistry registry used to resolve named XContent objects
 * @return the parsed object produced by {@code parserFunction}
 * @throws IOException on decompression or parsing failure
 */
public static <T> T inflate(String compressedString,
                            CheckedFunction<XContentParser, T, IOException> parserFunction,
                            NamedXContentRegistry xContentRegistry) throws IOException {
    // Delegate to the bounded overload with the default size cap.
    return inflate(compressedString, parserFunction, xContentRegistry, MAX_INFLATED_BYTES);
}

static <T> T inflate(String compressedString,
CheckedFunction<XContentParser, T, IOException> parserFunction,
NamedXContentRegistry xContentRegistry,
long maxBytes) throws IOException {
try(XContentParser parser = JsonXContent.jsonXContent.createParser(xContentRegistry,
LoggingDeprecationHandler.INSTANCE,
inflate(compressedString, MAX_INFLATED_BYTES))) {
inflate(compressedString, maxBytes))) {
return parserFunction.apply(parser);
} catch (XContentParseException parseException) {
SimpleBoundedInputStream.StreamSizeExceededException streamSizeCause =
(SimpleBoundedInputStream.StreamSizeExceededException)
ExceptionsHelper.unwrap(parseException, SimpleBoundedInputStream.StreamSizeExceededException.class);

if (streamSizeCause != null) {
// The root cause is that the model is too big.
throw new IOException("Cannot parse model definition as the content is larger than the maximum stream size of ["
+ streamSizeCause.getMaxBytes() + "] bytes. Max stream size is 10% of the JVM heap or 1GB whichever is smallest");
} else {
throw parseException;
}
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,19 @@ public final class SimpleBoundedInputStream extends InputStream {
private final long maxBytes;
private long numBytes;

/**
 * Thrown when a bounded stream read attempts to go past its configured byte
 * limit. Carries the limit that was exceeded so callers can report it.
 */
public static class StreamSizeExceededException extends IOException {
    // The byte limit that was in force when the stream overflowed.
    private final long maxBytes;

    /**
     * @param message  description of the failure
     * @param maxBytes the byte limit that was exceeded
     */
    public StreamSizeExceededException(String message, long maxBytes) {
        super(message);
        this.maxBytes = maxBytes;
    }

    /** @return the byte limit that was exceeded */
    public long getMaxBytes() {
        return maxBytes;
    }
}

public SimpleBoundedInputStream(InputStream inputStream, long maxBytes) {
this.in = ExceptionsHelper.requireNonNull(inputStream, "inputStream");
if (maxBytes < 0) {
Expand All @@ -31,13 +44,14 @@ public SimpleBoundedInputStream(InputStream inputStream, long maxBytes) {
/**
* A simple wrapper around the injected input stream that restricts the total number of bytes able to be read.
* @return The byte read.
* @throws IOException on failure or when byte limit is exceeded
* @throws StreamSizeExceededException when byte limit is exceeded
* @throws IOException on failure
*/
@Override
public int read() throws IOException {
// We have reached the maximum, signal stream completion.
if (numBytes >= maxBytes) {
throw new IOException("input stream exceeded maximum bytes of [" + maxBytes + "]");
throw new StreamSizeExceededException("input stream exceeded maximum bytes of [" + maxBytes + "]", maxBytes);
}
numBytes++;
return in.read();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,13 +46,34 @@ public void testInflateTooLargeStream() throws IOException {
int max = firstDeflate.getBytes(StandardCharsets.UTF_8).length + 10;
IOException ex = expectThrows(IOException.class,
() -> Streams.readFully(InferenceToXContentCompressor.inflate(firstDeflate, max)));
assertThat(ex.getMessage(), equalTo("input stream exceeded maximum bytes of [" + max + "]"));
assertThat(ex.getMessage(), equalTo("" +
"input stream exceeded maximum bytes of [" + max + "]"));
}

public void testInflateGarbage() {
    // Inflating bytes that were never deflated must fail with an IOException
    // rather than yield garbage output.
    String notDeflated = randomAlphaOfLength(10);
    expectThrows(IOException.class, () -> Streams.readFully(InferenceToXContentCompressor.inflate(notDeflated, 100L)));
}

public void testInflateParsingTooLargeStream() throws IOException {
    // Build a model definition padded with 100 random pre-processors so its
    // inflated form is much larger than its compressed form.
    TrainedModelDefinition modelDefinition = TrainedModelDefinitionTests.createRandomBuilder()
        .setPreProcessors(
            Stream.generate(
                () -> randomFrom(
                    FrequencyEncodingTests.createRandom(),
                    OneHotEncodingTests.createRandom(),
                    TargetMeanEncodingTests.createRandom()))
                .limit(100)
                .collect(Collectors.toList()))
        .build();
    String deflated = InferenceToXContentCompressor.deflate(modelDefinition);

    // A cap just above the compressed size is guaranteed to be exceeded once
    // the content is inflated back to its full size.
    int maxBytes = deflated.getBytes(StandardCharsets.UTF_8).length + 10;

    IOException thrown = expectThrows(IOException.class, () -> InferenceToXContentCompressor.inflate(deflated,
        parser -> TrainedModelDefinition.fromXContent(parser, true).build(),
        xContentRegistry(),
        maxBytes));

    String expectedMessage = "Cannot parse model definition as the content is larger than the maximum stream size of ["
        + maxBytes + "] bytes. Max stream size is 10% of the JVM heap or 1GB whichever is smallest";
    assertThat(thrown.getMessage(), equalTo(expectedMessage));
}

@Override
protected NamedXContentRegistry xContentRegistry() {
return new NamedXContentRegistry(new MlInferenceNamedXContentProvider().getNamedXContentParsers());
Expand Down