Skip to content

Commit 154d392

Browse files
authored
[ML][Inference] fixing ingest IT tests (#51267) (#51312)
Converts InferenceIngestIT into a `ESRestTestCase`. closes #51201
1 parent 749123b commit 154d392

File tree

1 file changed

+116
-155
lines changed
  • x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration

1 file changed

+116
-155
lines changed

x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/InferenceIngestIT.java

Lines changed: 116 additions & 155 deletions
Original file line numberDiff line numberDiff line change
@@ -5,141 +5,109 @@
55
*/
66
package org.elasticsearch.xpack.ml.integration;
77

8-
import org.elasticsearch.action.admin.indices.refresh.RefreshRequest;
9-
import org.elasticsearch.action.ingest.SimulateDocumentBaseResult;
10-
import org.elasticsearch.action.ingest.SimulatePipelineResponse;
11-
import org.elasticsearch.action.search.SearchRequest;
12-
import org.elasticsearch.common.bytes.BytesArray;
13-
import org.elasticsearch.common.xcontent.DeprecationHandler;
8+
import org.apache.http.util.EntityUtils;
9+
import org.elasticsearch.client.Request;
10+
import org.elasticsearch.client.Response;
11+
import org.elasticsearch.common.bytes.BytesReference;
12+
import org.elasticsearch.common.settings.Settings;
13+
import org.elasticsearch.common.util.concurrent.ThreadContext;
1414
import org.elasticsearch.common.xcontent.NamedXContentRegistry;
15+
import org.elasticsearch.common.xcontent.XContentBuilder;
16+
import org.elasticsearch.common.xcontent.XContentFactory;
1517
import org.elasticsearch.common.xcontent.XContentHelper;
16-
import org.elasticsearch.common.xcontent.XContentParser;
1718
import org.elasticsearch.common.xcontent.XContentType;
18-
import org.elasticsearch.index.mapper.MapperService;
19+
import org.elasticsearch.index.query.QueryBuilder;
1920
import org.elasticsearch.index.query.QueryBuilders;
20-
import org.elasticsearch.search.builder.SearchSourceBuilder;
21-
import org.elasticsearch.xpack.core.ml.action.DeleteTrainedModelAction;
22-
import org.elasticsearch.xpack.core.ml.action.PutTrainedModelAction;
21+
import org.elasticsearch.test.ExternalTestCluster;
22+
import org.elasticsearch.test.SecuritySettingsSourceField;
23+
import org.elasticsearch.test.rest.ESRestTestCase;
2324
import org.elasticsearch.xpack.core.ml.inference.MlInferenceNamedXContentProvider;
24-
import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig;
25+
import org.elasticsearch.xpack.core.ml.integration.MlRestTestStateCleaner;
2526
import org.junit.After;
2627
import org.junit.Before;
2728

2829
import java.io.IOException;
29-
import java.nio.charset.StandardCharsets;
3030
import java.util.HashMap;
31-
import java.util.List;
3231
import java.util.Map;
3332

33+
import static org.elasticsearch.xpack.core.security.authc.support.UsernamePasswordToken.basicAuthHeaderValue;
3434
import static org.hamcrest.CoreMatchers.containsString;
35-
import static org.hamcrest.Matchers.equalTo;
36-
import static org.hamcrest.Matchers.is;
3735

38-
public class InferenceIngestIT extends MlNativeAutodetectIntegTestCase {
36+
/**
37+
* This is a {@link ESRestTestCase} because the cleanup code in {@link ExternalTestCluster#ensureEstimatedStats()} causes problems
38+
* Specifically, ensuring the accounting breaker has been reset.
39+
* It has to do with `_simulate` not anything really to do with the ML code
40+
*/
41+
public class InferenceIngestIT extends ESRestTestCase {
42+
43+
private static final String BASIC_AUTH_VALUE_SUPER_USER =
44+
basicAuthHeaderValue("x_pack_rest_user", SecuritySettingsSourceField.TEST_PASSWORD_SECURE_STRING);
3945

4046
@Before
4147
public void createBothModels() throws Exception {
42-
client().execute(PutTrainedModelAction.INSTANCE, new PutTrainedModelAction.Request(buildClassificationModel())).actionGet();
43-
client().execute(PutTrainedModelAction.INSTANCE, new PutTrainedModelAction.Request(buildRegressionModel())).actionGet();
48+
Request request = new Request("PUT", "_ml/inference/test_classification");
49+
request.setJsonEntity(CLASSIFICATION_CONFIG);
50+
client().performRequest(request);
51+
52+
request = new Request("PUT", "_ml/inference/test_regression");
53+
request.setJsonEntity(REGRESSION_CONFIG);
54+
client().performRequest(request);
55+
}
56+
57+
@Override
58+
protected Settings restClientSettings() {
59+
return Settings.builder().put(ThreadContext.PREFIX + ".Authorization", BASIC_AUTH_VALUE_SUPER_USER).build();
4460
}
4561

4662
@After
47-
public void deleteBothModels() {
48-
client().execute(DeleteTrainedModelAction.INSTANCE, new DeleteTrainedModelAction.Request("test_classification")).actionGet();
49-
client().execute(DeleteTrainedModelAction.INSTANCE, new DeleteTrainedModelAction.Request("test_regression")).actionGet();
63+
public void cleanUpData() throws Exception {
64+
new MlRestTestStateCleaner(logger, adminClient()).clearMlMetadata();
65+
ESRestTestCase.waitForPendingTasks(adminClient());
66+
client().performRequest(new Request("DELETE", "_ml/inference/test_classification"));
67+
client().performRequest(new Request("DELETE", "_ml/inference/test_regression"));
5068
}
5169

5270
public void testPipelineCreationAndDeletion() throws Exception {
5371

5472
for (int i = 0; i < 10; i++) {
55-
assertThat(client().admin().cluster().preparePutPipeline("simple_classification_pipeline",
56-
new BytesArray(CLASSIFICATION_PIPELINE.getBytes(StandardCharsets.UTF_8)),
57-
XContentType.JSON).get().isAcknowledged(), is(true));
58-
59-
client().prepareIndex("index_for_inference_test", MapperService.SINGLE_MAPPING_NAME)
60-
.setSource(new HashMap<String, Object>(){{
61-
put("col1", randomFrom("female", "male"));
62-
put("col2", randomFrom("S", "M", "L", "XL"));
63-
put("col3", randomFrom("true", "false", "none", "other"));
64-
put("col4", randomIntBetween(0, 10));
65-
}})
66-
.setPipeline("simple_classification_pipeline")
67-
.get();
68-
69-
assertThat(client().admin().cluster().prepareDeletePipeline("simple_classification_pipeline").get().isAcknowledged(),
70-
is(true));
71-
72-
assertThat(client().admin().cluster().preparePutPipeline("simple_regression_pipeline",
73-
new BytesArray(REGRESSION_PIPELINE.getBytes(StandardCharsets.UTF_8)),
74-
XContentType.JSON).get().isAcknowledged(), is(true));
75-
76-
client().prepareIndex("index_for_inference_test", MapperService.SINGLE_MAPPING_NAME)
77-
.setSource(new HashMap<String, Object>(){{
78-
put("col1", randomFrom("female", "male"));
79-
put("col2", randomFrom("S", "M", "L", "XL"));
80-
put("col3", randomFrom("true", "false", "none", "other"));
81-
put("col4", randomIntBetween(0, 10));
82-
}})
83-
.setPipeline("simple_regression_pipeline")
84-
.get();
85-
86-
assertThat(client().admin().cluster().prepareDeletePipeline("simple_regression_pipeline").get().isAcknowledged(),
87-
is(true));
88-
}
73+
client().performRequest(putPipeline("simple_classification_pipeline", CLASSIFICATION_PIPELINE));
74+
client().performRequest(indexRequest("index_for_inference_test", "simple_classification_pipeline", generateSourceDoc()));
75+
client().performRequest(new Request("DELETE", "_ingest/pipeline/simple_classification_pipeline"));
8976

90-
assertThat(client().admin().cluster().preparePutPipeline("simple_classification_pipeline",
91-
new BytesArray(CLASSIFICATION_PIPELINE.getBytes(StandardCharsets.UTF_8)),
92-
XContentType.JSON).get().isAcknowledged(), is(true));
77+
client().performRequest(putPipeline("simple_regression_pipeline", REGRESSION_PIPELINE));
78+
client().performRequest(indexRequest("index_for_inference_test", "simple_regression_pipeline", generateSourceDoc()));
79+
client().performRequest(new Request("DELETE", "_ingest/pipeline/simple_regression_pipeline"));
80+
}
9381

94-
assertThat(client().admin().cluster().preparePutPipeline("simple_regression_pipeline",
95-
new BytesArray(REGRESSION_PIPELINE.getBytes(StandardCharsets.UTF_8)),
96-
XContentType.JSON).get().isAcknowledged(), is(true));
82+
client().performRequest(putPipeline("simple_classification_pipeline", CLASSIFICATION_PIPELINE));
83+
client().performRequest(putPipeline("simple_regression_pipeline", REGRESSION_PIPELINE));
9784

9885
for (int i = 0; i < 10; i++) {
99-
client().prepareIndex("index_for_inference_test", MapperService.SINGLE_MAPPING_NAME)
100-
.setSource(generateSourceDoc())
101-
.setPipeline("simple_classification_pipeline")
102-
.get();
103-
104-
client().prepareIndex("index_for_inference_test", MapperService.SINGLE_MAPPING_NAME)
105-
.setSource(generateSourceDoc())
106-
.setPipeline("simple_regression_pipeline")
107-
.get();
86+
client().performRequest(indexRequest("index_for_inference_test", "simple_classification_pipeline", generateSourceDoc()));
87+
client().performRequest(indexRequest("index_for_inference_test", "simple_regression_pipeline", generateSourceDoc()));
10888
}
10989

110-
assertThat(client().admin().cluster().prepareDeletePipeline("simple_classification_pipeline").get().isAcknowledged(),
111-
is(true));
112-
113-
assertThat(client().admin().cluster().prepareDeletePipeline("simple_regression_pipeline").get().isAcknowledged(),
114-
is(true));
115-
116-
client().admin().indices().refresh(new RefreshRequest("index_for_inference_test")).get();
117-
118-
assertThat(client().search(new SearchRequest().indices("index_for_inference_test")
119-
.source(new SearchSourceBuilder()
120-
.size(0)
121-
.trackTotalHits(true)
122-
.query(QueryBuilders.boolQuery()
123-
.filter(
124-
QueryBuilders.existsQuery("ml.inference.regression.predicted_value"))))).get().getHits().getTotalHits().value,
125-
equalTo(20L));
126-
127-
assertThat(client().search(new SearchRequest().indices("index_for_inference_test")
128-
.source(new SearchSourceBuilder()
129-
.size(0)
130-
.trackTotalHits(true)
131-
.query(QueryBuilders.boolQuery()
132-
.filter(
133-
QueryBuilders.existsQuery("ml.inference.classification.predicted_value")))))
134-
.get()
135-
.getHits()
136-
.getTotalHits()
137-
.value,
138-
equalTo(20L));
90+
client().performRequest(new Request("DELETE", "_ingest/pipeline/simple_regression_pipeline"));
91+
client().performRequest(new Request("DELETE", "_ingest/pipeline/simple_classification_pipeline"));
13992

93+
client().performRequest(new Request("POST", "index_for_inference_test/_refresh"));
94+
95+
96+
Response searchResponse = client().performRequest(searchRequest("index_for_inference_test",
97+
QueryBuilders.boolQuery()
98+
.filter(
99+
QueryBuilders.existsQuery("ml.inference.regression.predicted_value"))));
100+
assertThat(EntityUtils.toString(searchResponse.getEntity()), containsString("\"value\":20"));
101+
102+
searchResponse = client().performRequest(searchRequest("index_for_inference_test",
103+
QueryBuilders.boolQuery()
104+
.filter(
105+
QueryBuilders.existsQuery("ml.inference.classification.predicted_value"))));
106+
107+
assertThat(EntityUtils.toString(searchResponse.getEntity()), containsString("\"value\":20"));
140108
}
141109

142-
public void testSimulate() {
110+
public void testSimulate() throws IOException {
143111
String source = "{\n" +
144112
" \"pipeline\": {\n" +
145113
" \"processors\": [\n" +
@@ -181,15 +149,10 @@ public void testSimulate() {
181149
" }}]\n" +
182150
"}";
183151

184-
SimulatePipelineResponse response = client().admin().cluster()
185-
.prepareSimulatePipeline(new BytesArray(source.getBytes(StandardCharsets.UTF_8)),
186-
XContentType.JSON).get();
187-
SimulateDocumentBaseResult baseResult = (SimulateDocumentBaseResult)response.getResults().get(0);
188-
assertThat(baseResult.getIngestDocument().getFieldValue("ml.regression.predicted_value", Double.class), equalTo(1.0));
189-
assertThat(baseResult.getIngestDocument().getFieldValue("ml.classification.predicted_value", String.class),
190-
equalTo("second"));
191-
assertThat(baseResult.getIngestDocument().getFieldValue("ml.classification.result_class_prob", List.class).size(),
192-
equalTo(2));
152+
Response response = client().performRequest(simulateRequest(source));
153+
String responseString = EntityUtils.toString(response.getEntity());
154+
assertThat(responseString, containsString("\"predicted_value\":\"second\""));
155+
assertThat(responseString, containsString("\"predicted_value\":1.0"));
193156

194157
String sourceWithMissingModel = "{\n" +
195158
" \"pipeline\": {\n" +
@@ -217,15 +180,13 @@ public void testSimulate() {
217180
" }}]\n" +
218181
"}";
219182

220-
response = client().admin().cluster()
221-
.prepareSimulatePipeline(new BytesArray(sourceWithMissingModel.getBytes(StandardCharsets.UTF_8)),
222-
XContentType.JSON).get();
183+
response = client().performRequest(simulateRequest(sourceWithMissingModel));
184+
responseString = EntityUtils.toString(response.getEntity());
223185

224-
assertThat(((SimulateDocumentBaseResult) response.getResults().get(0)).getFailure().getMessage(),
225-
containsString("Could not find trained model [test_classification_missing]"));
186+
assertThat(responseString, containsString("Could not find trained model [test_classification_missing]"));
226187
}
227188

228-
public void testSimulateLangIdent() {
189+
public void testSimulateLangIdent() throws IOException {
229190
String source = "{\n" +
230191
" \"pipeline\": {\n" +
231192
" \"processors\": [\n" +
@@ -244,11 +205,43 @@ public void testSimulateLangIdent() {
244205
" }}]\n" +
245206
"}";
246207

247-
SimulatePipelineResponse response = client().admin().cluster()
248-
.prepareSimulatePipeline(new BytesArray(source.getBytes(StandardCharsets.UTF_8)),
249-
XContentType.JSON).get();
250-
SimulateDocumentBaseResult baseResult = (SimulateDocumentBaseResult)response.getResults().get(0);
251-
assertThat(baseResult.getIngestDocument().getFieldValue("ml.inference.predicted_value", String.class), equalTo("en"));
208+
Response response = client().performRequest(simulateRequest(source));
209+
assertThat(EntityUtils.toString(response.getEntity()), containsString("\"predicted_value\":\"en\""));
210+
}
211+
212+
private static Request simulateRequest(String jsonEntity) {
213+
Request request = new Request("POST", "_ingest/pipeline/_simulate");
214+
request.setJsonEntity(jsonEntity);
215+
return request;
216+
}
217+
218+
private static Request indexRequest(String index, String pipeline, Map<String, Object> doc) throws IOException {
219+
try(XContentBuilder xContentBuilder = XContentFactory.jsonBuilder().map(doc)) {
220+
return indexRequest(index,
221+
pipeline,
222+
XContentHelper.convertToJson(BytesReference.bytes(xContentBuilder), false, XContentType.JSON));
223+
}
224+
}
225+
226+
private static Request indexRequest(String index, String pipeline, String doc) {
227+
Request request = new Request("POST", index + "/_doc?pipeline=" + pipeline);
228+
request.setJsonEntity(doc);
229+
return request;
230+
}
231+
232+
private static Request putPipeline(String pipelineId, String pipelineDefinition) {
233+
Request request = new Request("PUT", "_ingest/pipeline/" + pipelineId);
234+
request.setJsonEntity(pipelineDefinition);
235+
return request;
236+
}
237+
238+
private static Request searchRequest(String index, QueryBuilder queryBuilder) throws IOException {
239+
BytesReference reference = XContentHelper.toXContent(queryBuilder, XContentType.JSON, false);
240+
String queryJson = XContentHelper.convertToJson(reference, false, XContentType.JSON);
241+
String json = "{\"query\": " + queryJson + "}";
242+
Request request = new Request("GET", index + "/_search?track_total_hits=true");
243+
request.setJsonEntity(json);
244+
return request;
252245
}
253246

254247
private Map<String, Object> generateSourceDoc() {
@@ -380,16 +373,9 @@ private Map<String, Object> generateSourceDoc() {
380373
"}";
381374

382375
private static final String REGRESSION_CONFIG = "{" +
383-
" \"model_id\": \"test_regression\",\n" +
384376
" \"input\":{\"field_names\":[\"col1\",\"col2\",\"col3\",\"col4\"]}," +
385377
" \"description\": \"test model for regression\",\n" +
386-
" \"version\": \"7.6.0\",\n" +
387-
" \"definition\": " + REGRESSION_DEFINITION + ","+
388-
" \"license_level\": \"platinum\",\n" +
389-
" \"created_by\": \"ml_test\",\n" +
390-
" \"estimated_heap_memory_usage_bytes\": 0," +
391-
" \"estimated_operations\": 0," +
392-
" \"created_time\": 0" +
378+
" \"definition\": " + REGRESSION_DEFINITION +
393379
"}";
394380

395381
private static final String CLASSIFICATION_DEFINITION = "{" +
@@ -512,41 +498,16 @@ private Map<String, Object> generateSourceDoc() {
512498
" }\n" +
513499
"}";
514500

515-
private TrainedModelConfig buildClassificationModel() throws IOException {
516-
try (XContentParser parser = XContentHelper.createParser(xContentRegistry(),
517-
DeprecationHandler.THROW_UNSUPPORTED_OPERATION,
518-
new BytesArray(CLASSIFICATION_CONFIG),
519-
XContentType.JSON)) {
520-
return TrainedModelConfig.LENIENT_PARSER.apply(parser, null).build();
521-
}
522-
}
523-
524-
private TrainedModelConfig buildRegressionModel() throws IOException {
525-
try (XContentParser parser = XContentHelper.createParser(xContentRegistry(),
526-
DeprecationHandler.THROW_UNSUPPORTED_OPERATION,
527-
new BytesArray(REGRESSION_CONFIG),
528-
XContentType.JSON)) {
529-
return TrainedModelConfig.LENIENT_PARSER.apply(parser, null).build();
530-
}
531-
}
532-
533501
@Override
534502
protected NamedXContentRegistry xContentRegistry() {
535503
return new NamedXContentRegistry(new MlInferenceNamedXContentProvider().getNamedXContentParsers());
536504
}
537505

538506
private static final String CLASSIFICATION_CONFIG = "" +
539507
"{\n" +
540-
" \"model_id\": \"test_classification\",\n" +
541508
" \"input\":{\"field_names\":[\"col1\",\"col2\",\"col3\",\"col4\"]}," +
542509
" \"description\": \"test model for classification\",\n" +
543-
" \"version\": \"7.6.0\",\n" +
544-
" \"definition\": " + CLASSIFICATION_DEFINITION + ","+
545-
" \"license_level\": \"platinum\",\n" +
546-
" \"created_by\": \"es_test\",\n" +
547-
" \"estimated_heap_memory_usage_bytes\": 0," +
548-
" \"estimated_operations\": 0," +
549-
" \"created_time\": 0\n" +
510+
" \"definition\": " + CLASSIFICATION_DEFINITION +
550511
"}";
551512

552513
private static final String CLASSIFICATION_PIPELINE = "{" +

0 commit comments

Comments
 (0)