Skip to content

Commit 9a8be68

Browse files
authored
[ML] Integration test with a simple PyTorch model (#73757)
An end-to-end test that launches the PyTorch inference process, loads a model and evaluates it. The model is a hardcoded PyTorch TorchScript model, base64-encoded in the test. Results are returned by the API without any processing, via the `PassThroughResultProcessor`.
1 parent 0061823 commit 9a8be68

File tree

21 files changed

+550
-58
lines changed

21 files changed

+550
-58
lines changed

x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/MlInferenceNamedXContentProvider.java

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
import org.elasticsearch.xpack.core.ml.inference.results.FillMaskResults;
2323
import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults;
2424
import org.elasticsearch.xpack.core.ml.inference.results.NerResults;
25+
import org.elasticsearch.xpack.core.ml.inference.results.PyTorchPassThroughResults;
2526
import org.elasticsearch.xpack.core.ml.inference.results.RegressionInferenceResults;
2627
import org.elasticsearch.xpack.core.ml.inference.results.WarningInferenceResults;
2728
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfig;
@@ -223,6 +224,9 @@ public List<NamedWriteableRegistry.Entry> getNamedWriteables() {
223224
namedWriteables.add(new NamedWriteableRegistry.Entry(InferenceResults.class,
224225
FillMaskResults.NAME,
225226
FillMaskResults::new));
227+
namedWriteables.add(new NamedWriteableRegistry.Entry(InferenceResults.class,
228+
PyTorchPassThroughResults.NAME,
229+
PyTorchPassThroughResults::new));
226230

227231
// Inference Configs
228232
namedWriteables.add(new NamedWriteableRegistry.Entry(InferenceConfig.class,
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,83 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License
4+
* 2.0; you may not use this file except in compliance with the Elastic License
5+
* 2.0.
6+
*/
7+
8+
package org.elasticsearch.xpack.core.ml.inference.results;
9+
10+
import org.elasticsearch.common.ParseField;
11+
import org.elasticsearch.common.io.stream.StreamInput;
12+
import org.elasticsearch.common.io.stream.StreamOutput;
13+
import org.elasticsearch.common.xcontent.XContentBuilder;
14+
15+
import java.io.IOException;
16+
import java.util.Arrays;
17+
import java.util.LinkedHashMap;
18+
import java.util.Map;
19+
20+
public class PyTorchPassThroughResults implements InferenceResults {
21+
22+
public static final String NAME = "pass_through_result";
23+
static final String DEFAULT_RESULTS_FIELD = "results";
24+
25+
private static final ParseField INFERENCE = new ParseField("inference");
26+
27+
private final double[][] inference;
28+
29+
public PyTorchPassThroughResults(double[][] inference) {
30+
this.inference = inference;
31+
}
32+
33+
public PyTorchPassThroughResults(StreamInput in) throws IOException {
34+
inference = in.readArray(StreamInput::readDoubleArray, length -> new double[length][]);
35+
}
36+
37+
public double[][] getInference() {
38+
return inference;
39+
}
40+
41+
@Override
42+
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
43+
builder.startObject();
44+
builder.field(INFERENCE.getPreferredName(), inference);
45+
builder.endObject();
46+
return builder;
47+
}
48+
49+
@Override
50+
public String getWriteableName() {
51+
return NAME;
52+
}
53+
54+
@Override
55+
public void writeTo(StreamOutput out) throws IOException {
56+
out.writeArray(StreamOutput::writeDoubleArray, inference);
57+
}
58+
59+
@Override
60+
public Map<String, Object> asMap() {
61+
Map<String, Object> map = new LinkedHashMap<>();
62+
map.put(DEFAULT_RESULTS_FIELD, inference);
63+
return map;
64+
}
65+
66+
@Override
67+
public Object predictedValue() {
68+
throw new UnsupportedOperationException("[" + NAME + "] does not support a single predicted value");
69+
}
70+
71+
@Override
72+
public boolean equals(Object o) {
73+
if (this == o) return true;
74+
if (o == null || getClass() != o.getClass()) return false;
75+
PyTorchPassThroughResults that = (PyTorchPassThroughResults) o;
76+
return Arrays.deepEquals(inference, that.inference);
77+
}
78+
79+
@Override
80+
public int hashCode() {
81+
return Arrays.deepHashCode(inference);
82+
}
83+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License
4+
* 2.0; you may not use this file except in compliance with the Elastic License
5+
* 2.0.
6+
*/
7+
8+
package org.elasticsearch.xpack.core.ml.inference.results;
9+
10+
import org.elasticsearch.common.io.stream.Writeable;
11+
import org.elasticsearch.test.AbstractWireSerializingTestCase;
12+
13+
import java.util.Map;
14+
15+
import static org.hamcrest.Matchers.hasSize;
16+
17+
public class PyTorchPassThroughResultsTests extends AbstractWireSerializingTestCase<PyTorchPassThroughResults> {
18+
@Override
19+
protected Writeable.Reader<PyTorchPassThroughResults> instanceReader() {
20+
return PyTorchPassThroughResults::new;
21+
}
22+
23+
@Override
24+
protected PyTorchPassThroughResults createTestInstance() {
25+
int rows = randomIntBetween(1, 10);
26+
int columns = randomIntBetween(1, 10);
27+
double [][] arr = new double[rows][columns];
28+
for (int i=0; i<rows; i++) {
29+
for (int j=0; j<columns; j++) {
30+
arr[i][j] = randomDouble();
31+
}
32+
}
33+
34+
return new PyTorchPassThroughResults(arr);
35+
}
36+
37+
public void testAsMap() {
38+
PyTorchPassThroughResults testInstance = createTestInstance();
39+
Map<String, Object> asMap = testInstance.asMap();
40+
assertThat(asMap.keySet(), hasSize(1));
41+
assertArrayEquals(testInstance.getInference(), (double[][]) asMap.get(PyTorchPassThroughResults.DEFAULT_RESULTS_FIELD));
42+
}
43+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,189 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License
4+
* 2.0; you may not use this file except in compliance with the Elastic License
5+
* 2.0.
6+
*/
7+
8+
package org.elasticsearch.xpack.ml.integration;
9+
10+
import org.apache.http.util.EntityUtils;
11+
import org.elasticsearch.client.Request;
12+
import org.elasticsearch.client.Response;
13+
import org.elasticsearch.common.settings.Settings;
14+
import org.elasticsearch.common.util.concurrent.ThreadContext;
15+
import org.elasticsearch.test.SecuritySettingsSourceField;
16+
import org.elasticsearch.test.rest.ESRestTestCase;
17+
import org.elasticsearch.xpack.core.security.authc.support.UsernamePasswordToken;
18+
19+
import java.io.IOException;
20+
import java.util.Base64;
21+
22+
/**
23+
* This test uses a tiny hardcoded base64 encoded PyTorch TorchScript model.
24+
* The model was created with the following python script and returns a
25+
* Tensor of 1s. The simplicity of the model is not important as the aim
26+
* is to test loading a model into the PyTorch process and evaluating it.
27+
*
28+
* ## Start Python
29+
* import torch
30+
* class SuperSimple(torch.nn.Module):
31+
* def forward(self, input_ids=None, token_type_ids=None, position_ids=None, inputs_embeds=None):
32+
* return torch.ones((input_ids.size()[0], 2), dtype=torch.float32)
33+
*
34+
* model = SuperSimple()
35+
* input_ids = torch.tensor([1, 2, 3, 4, 5])
36+
* the_rest = torch.ones(5)
37+
* result = model.forward(input_ids, the_rest, the_rest, the_rest)
38+
* print(result)
39+
*
40+
* traced_model = torch.jit.trace(model, (input_ids, the_rest, the_rest, the_rest))
41+
* torch.jit.save(traced_model, "simplemodel.pt")
42+
* ## End Python
43+
*/
44+
import static org.hamcrest.Matchers.equalTo;
45+
46+
public class PyTorchModelIT extends ESRestTestCase {
47+
48+
private static final String BASIC_AUTH_VALUE_SUPER_USER =
49+
UsernamePasswordToken.basicAuthHeaderValue("x_pack_rest_user", SecuritySettingsSourceField.TEST_PASSWORD_SECURE_STRING);
50+
51+
@Override
52+
protected Settings restClientSettings() {
53+
return Settings.builder().put(ThreadContext.PREFIX + ".Authorization", BASIC_AUTH_VALUE_SUPER_USER).build();
54+
}
55+
56+
private static final String MODEL_INDEX = "model_store";
57+
private static final String MODEL_ID ="simple_model_to_evaluate";
58+
private static final String BASE_64_ENCODED_MODEL =
59+
"UEsDBAAACAgAAAAAAAAAAAAAAAAAAAAAAAAUAA4Ac2ltcGxlbW9kZWwvZGF0YS5wa2xGQgoAWlpaWlpaWlpaWoACY19fdG9yY2hfXwp" +
60+
"TdXBlclNpbXBsZQpxACmBfShYCAAAAHRyYWluaW5ncQGIdWJxAi5QSwcIXOpBBDQAAAA0AAAAUEsDBBQACAgIAAAAAAAAAAAAAAAAAA" +
61+
"AAAAAdAEEAc2ltcGxlbW9kZWwvY29kZS9fX3RvcmNoX18ucHlGQj0AWlpaWlpaWlpaWlpaWlpaWlpaWlpaWlpaWlpaWlpaWlpaWlpaW" +
62+
"lpaWlpaWlpaWlpaWlpaWlpaWlpaWlpaWnWOMWvDMBCF9/yKI5MMrnHTQsHgjt2aJdlCEIp9SgWSTpykFvfXV1htaYds0nfv473Jqhjh" +
63+
"kAPywbhgUbzSnC02wwZAyqBYOUzIUUoY4XRe6SVr/Q8lVsYbf4UBLkS2kBk1aOIPxbOIaPVQtEQ8vUnZ/WlrSxTA+JCTNHMc4Ig+Ele" +
64+
"s+Jod+iR3N/jDDf74wxu4e/5+DmtE9mUyhdgFNq7bZ3ekehbruC6aTxS/c1rom6Z698WrEfIYxcn4JGTftLA7tzCnJeD41IJVC+U07k" +
65+
"umUHw3E47Vqh+xnULeFisYLx064mV8UTZibWFMmX0p23wBUEsHCE0EGH3yAAAAlwEAAFBLAwQUAAgICAAAAAAAAAAAAAAAAAAAAAAAJ" +
66+
"wA5AHNpbXBsZW1vZGVsL2NvZGUvX190b3JjaF9fLnB5LmRlYnVnX3BrbEZCNQBaWlpaWlpaWlpaWlpaWlpaWlpaWlpaWlpaWlpaWlpa" +
67+
"WlpaWlpaWlpaWlpaWlpaWlpaWlpaWrWST0+DMBiHW6bOod/BGS94kKpo2Mwyox5x3pbgiXSAFtdR/nQu3IwHiZ9oX88CaeGu9tL0efq" +
68+
"+v8P7fmiGA1wgTgoIcECZQqe6vmYD6G4hAJOcB1E8NazTm+ELyzY4C3Q0z8MsRwF+j4JlQUPEEo5wjH0WB9hCNFqgpOCExZY5QnnEw7" +
69+
"ME+0v8GuaIs8wnKI7RigVrKkBzm0lh2OdjkeHllG28f066vK6SfEypF60S+vuYt4gjj2fYr/uPrSvRv356TepfJ9iWJRN0OaELQSZN3" +
70+
"FRPNbcP1PTSntMr0x0HzLZQjPYIEo3UaFeiISRKH0Mil+BE/dyT1m7tCBLwVO1MX4DK3bbuTlXuy8r71j5Aoho66udAoseOnrdVzx28" +
71+
"UFW6ROuO/lT6QKKyo79VU54emj9QSwcInsUTEDMBAAAFAwAAUEsDBAAACAgAAAAAAAAAAAAAAAAAAAAAAAAZAAYAc2ltcGxlbW9kZWw" +
72+
"vY29uc3RhbnRzLnBrbEZCAgBaWoACKS5QSwcIbS8JVwQAAAAEAAAAUEsDBAAACAgAAAAAAAAAAAAAAAAAAAAAAAATADsAc2ltcGxlbW" +
73+
"9kZWwvdmVyc2lvbkZCNwBaWlpaWlpaWlpaWlpaWlpaWlpaWlpaWlpaWlpaWlpaWlpaWlpaWlpaWlpaWlpaWlpaWlpaWlpaMwpQSwcI0" +
74+
"Z5nVQIAAAACAAAAUEsBAgAAAAAICAAAAAAAAFzqQQQ0AAAANAAAABQAAAAAAAAAAAAAAAAAAAAAAHNpbXBsZW1vZGVsL2RhdGEucGts" +
75+
"UEsBAgAAFAAICAgAAAAAAE0EGH3yAAAAlwEAAB0AAAAAAAAAAAAAAAAAhAAAAHNpbXBsZW1vZGVsL2NvZGUvX190b3JjaF9fLnB5UEs" +
76+
"BAgAAFAAICAgAAAAAAJ7FExAzAQAABQMAACcAAAAAAAAAAAAAAAAAAgIAAHNpbXBsZW1vZGVsL2NvZGUvX190b3JjaF9fLnB5LmRlYn" +
77+
"VnX3BrbFBLAQIAAAAACAgAAAAAAABtLwlXBAAAAAQAAAAZAAAAAAAAAAAAAAAAAMMDAABzaW1wbGVtb2RlbC9jb25zdGFudHMucGtsU" +
78+
"EsBAgAAAAAICAAAAAAAANGeZ1UCAAAAAgAAABMAAAAAAAAAAAAAAAAAFAQAAHNpbXBsZW1vZGVsL3ZlcnNpb25QSwYGLAAAAAAAAAAe" +
79+
"Ay0AAAAAAAAAAAAFAAAAAAAAAAUAAAAAAAAAagEAAAAAAACSBAAAAAAAAFBLBgcAAAAA/AUAAAAAAAABAAAAUEsFBgAAAAAFAAUAagE" +
80+
"AAJIEAAAAAA==";
81+
private static final int RAW_MODEL_SIZE; // size of the model before base64 encoding
82+
static {
83+
RAW_MODEL_SIZE = Base64.getDecoder().decode(BASE_64_ENCODED_MODEL).length;
84+
}
85+
86+
public void testEvaluate() throws IOException {
87+
createModelStoreIndex();
88+
putTaskConfig();
89+
putModelDefinition();
90+
createTrainedModel();
91+
startDeployment();
92+
try {
93+
Response inference = infer("my words");
94+
assertThat(EntityUtils.toString(inference.getEntity()), equalTo("{\"inference\":[[1.0,1.0]]}"));
95+
} finally {
96+
stopDeployment();
97+
}
98+
}
99+
100+
private void putModelDefinition() throws IOException {
101+
Request request = new Request("PUT", "/" + MODEL_INDEX + "/_doc/trained_model_definition_doc-" + MODEL_ID + "-0");
102+
request.setJsonEntity("{ " +
103+
"\"doc_type\": \"trained_model_definition_doc\"," +
104+
"\"model_id\": \"" + MODEL_ID +"\"," +
105+
"\"doc_num\": 0," +
106+
"\"definition_length\":" + RAW_MODEL_SIZE + "," +
107+
"\"total_definition_length\":" + RAW_MODEL_SIZE + "," +
108+
"\"compression_version\": 1," +
109+
"\"definition\": \"" + BASE_64_ENCODED_MODEL + "\"," +
110+
"\"eos\": true" +
111+
"}");
112+
client().performRequest(request);
113+
}
114+
115+
private void createModelStoreIndex() throws IOException {
116+
Request request = new Request("PUT", "/" + MODEL_INDEX);
117+
request.setJsonEntity("{ " +
118+
"\"mappings\": {\n" +
119+
" \"properties\": {\n" +
120+
" \"doc_type\": { \"type\": \"keyword\" },\n" +
121+
" \"model_id\": { \"type\": \"keyword\" },\n" +
122+
" \"definition_length\": { \"type\": \"long\" },\n" +
123+
" \"total_definition_length\": { \"type\": \"long\" },\n" +
124+
" \"compression_version\": { \"type\": \"long\" },\n" +
125+
" \"definition\": { \"type\": \"binary\" },\n" +
126+
" \"eos\": { \"type\": \"boolean\" },\n" +
127+
" \"task_type\": { \"type\": \"keyword\" },\n" +
128+
" \"vocab\": { \"type\": \"keyword\" },\n" +
129+
" \"with_special_tokens\": { \"type\": \"boolean\" },\n" +
130+
" \"do_lower_case\": { \"type\": \"boolean\" }\n" +
131+
" }\n" +
132+
" }" +
133+
"}");
134+
client().performRequest(request);
135+
}
136+
137+
private void putTaskConfig() throws IOException {
138+
Request request = new Request("PUT", "/" + MODEL_INDEX + "/_doc/" + MODEL_ID + "_task_config");
139+
request.setJsonEntity("{ " +
140+
"\"task_type\": \"bert_pass_through\",\n" +
141+
"\"with_special_tokens\": false," +
142+
"\"vocab\": [\"these\", \"are\", \"my\", \"words\"]\n" +
143+
"}");
144+
client().performRequest(request);
145+
}
146+
147+
private void createTrainedModel() throws IOException {
148+
Request request = new Request("PUT", "/_ml/trained_models/" + MODEL_ID);
149+
request.setJsonEntity("{ " +
150+
" \"description\": \"simple model for testing\",\n" +
151+
" \"model_type\": \"pytorch\",\n" +
152+
" \"inference_config\": {\n" +
153+
" \"classification\": {\n" +
154+
" \"num_top_classes\": 1\n" +
155+
" }\n" +
156+
" },\n" +
157+
" \"input\": {\n" +
158+
" \"field_names\": [\"text_field\"]\n" +
159+
" },\n" +
160+
" \"location\": {\n" +
161+
" \"index\": {\n" +
162+
" \"model_id\": \"" + MODEL_ID + "\",\n" +
163+
" \"name\": \"" + MODEL_INDEX + "\"\n" +
164+
" }\n" +
165+
" }" +
166+
"}");
167+
client().performRequest(request);
168+
}
169+
170+
private void startDeployment() throws IOException {
171+
Request request = new Request("POST", "/_ml/trained_models/" + MODEL_ID + "/deployment/_start");
172+
Response response = client().performRequest(request);
173+
logger.info("Start response: " + EntityUtils.toString(response.getEntity()));
174+
}
175+
176+
private void stopDeployment() throws IOException {
177+
Request request = new Request("POST", "/_ml/trained_models/" + MODEL_ID + "/deployment/_stop");
178+
client().performRequest(request);
179+
}
180+
181+
private Response infer(String input) throws IOException {
182+
Request request = new Request("POST", "/_ml/trained_models/" + MODEL_ID + "/deployment/_infer");
183+
request.setJsonEntity("{ " +
184+
"\"input\": \"" + input + "\"\n" +
185+
"}");
186+
return client().performRequest(request);
187+
}
188+
189+
}

0 commit comments

Comments
 (0)