5
5
*/
6
6
package org .elasticsearch .xpack .ml .integration ;
7
7
8
- import org .elasticsearch . action . admin . indices . refresh . RefreshRequest ;
9
- import org .elasticsearch .action . ingest . SimulateDocumentBaseResult ;
10
- import org .elasticsearch .action . ingest . SimulatePipelineResponse ;
11
- import org .elasticsearch .action . search . SearchRequest ;
12
- import org .elasticsearch .common .bytes . BytesArray ;
13
- import org .elasticsearch .common .xcontent . DeprecationHandler ;
8
+ import org .apache . http . util . EntityUtils ;
9
+ import org .elasticsearch .client . Request ;
10
+ import org .elasticsearch .client . Response ;
11
+ import org .elasticsearch .common . bytes . BytesReference ;
12
+ import org .elasticsearch .common .settings . Settings ;
13
+ import org .elasticsearch .common .util . concurrent . ThreadContext ;
14
14
import org .elasticsearch .common .xcontent .NamedXContentRegistry ;
15
+ import org .elasticsearch .common .xcontent .XContentBuilder ;
16
+ import org .elasticsearch .common .xcontent .XContentFactory ;
15
17
import org .elasticsearch .common .xcontent .XContentHelper ;
16
- import org .elasticsearch .common .xcontent .XContentParser ;
17
18
import org .elasticsearch .common .xcontent .XContentType ;
18
- import org .elasticsearch .index .mapper . MapperService ;
19
+ import org .elasticsearch .index .query . QueryBuilder ;
19
20
import org .elasticsearch .index .query .QueryBuilders ;
20
- import org .elasticsearch .search . builder . SearchSourceBuilder ;
21
- import org .elasticsearch .xpack . core . ml . action . DeleteTrainedModelAction ;
22
- import org .elasticsearch .xpack . core . ml . action . PutTrainedModelAction ;
21
+ import org .elasticsearch .test . ExternalTestCluster ;
22
+ import org .elasticsearch .test . SecuritySettingsSourceField ;
23
+ import org .elasticsearch .test . rest . ESRestTestCase ;
23
24
import org .elasticsearch .xpack .core .ml .inference .MlInferenceNamedXContentProvider ;
24
- import org .elasticsearch .xpack .core .ml .inference . TrainedModelConfig ;
25
+ import org .elasticsearch .xpack .core .ml .integration . MlRestTestStateCleaner ;
25
26
import org .junit .After ;
26
27
import org .junit .Before ;
27
28
28
29
import java .io .IOException ;
29
- import java .nio .charset .StandardCharsets ;
30
30
import java .util .HashMap ;
31
- import java .util .List ;
32
31
import java .util .Map ;
33
32
33
+ import static org .elasticsearch .xpack .core .security .authc .support .UsernamePasswordToken .basicAuthHeaderValue ;
34
34
import static org .hamcrest .CoreMatchers .containsString ;
35
- import static org .hamcrest .Matchers .equalTo ;
36
- import static org .hamcrest .Matchers .is ;
37
35
38
- public class InferenceIngestIT extends MlNativeAutodetectIntegTestCase {
36
+ /**
37
+ * This is a {@link ESRestTestCase} because the cleanup code in {@link ExternalTestCluster#ensureEstimatedStats()} causes problems
38
+ * Specifically, ensuring the accounting breaker has been reset.
39
+ * It has to do with `_simulate` not anything really to do with the ML code
40
+ */
41
+ public class InferenceIngestIT extends ESRestTestCase {
42
+
43
+ private static final String BASIC_AUTH_VALUE_SUPER_USER =
44
+ basicAuthHeaderValue ("x_pack_rest_user" , SecuritySettingsSourceField .TEST_PASSWORD_SECURE_STRING );
39
45
40
46
@ Before
41
47
public void createBothModels () throws Exception {
42
- client ().execute (PutTrainedModelAction .INSTANCE , new PutTrainedModelAction .Request (buildClassificationModel ())).actionGet ();
43
- client ().execute (PutTrainedModelAction .INSTANCE , new PutTrainedModelAction .Request (buildRegressionModel ())).actionGet ();
48
+ Request request = new Request ("PUT" , "_ml/inference/test_classification" );
49
+ request .setJsonEntity (CLASSIFICATION_CONFIG );
50
+ client ().performRequest (request );
51
+
52
+ request = new Request ("PUT" , "_ml/inference/test_regression" );
53
+ request .setJsonEntity (REGRESSION_CONFIG );
54
+ client ().performRequest (request );
55
+ }
56
+
57
+ @ Override
58
+ protected Settings restClientSettings () {
59
+ return Settings .builder ().put (ThreadContext .PREFIX + ".Authorization" , BASIC_AUTH_VALUE_SUPER_USER ).build ();
44
60
}
45
61
46
62
@ After
47
- public void deleteBothModels () {
48
- client ().execute (DeleteTrainedModelAction .INSTANCE , new DeleteTrainedModelAction .Request ("test_classification" )).actionGet ();
49
- client ().execute (DeleteTrainedModelAction .INSTANCE , new DeleteTrainedModelAction .Request ("test_regression" )).actionGet ();
63
+ public void cleanUpData () throws Exception {
64
+ new MlRestTestStateCleaner (logger , adminClient ()).clearMlMetadata ();
65
+ ESRestTestCase .waitForPendingTasks (adminClient ());
66
+ client ().performRequest (new Request ("DELETE" , "_ml/inference/test_classification" ));
67
+ client ().performRequest (new Request ("DELETE" , "_ml/inference/test_regression" ));
50
68
}
51
69
52
70
public void testPipelineCreationAndDeletion () throws Exception {
53
71
54
72
for (int i = 0 ; i < 10 ; i ++) {
55
- assertThat (client ().admin ().cluster ().preparePutPipeline ("simple_classification_pipeline" ,
56
- new BytesArray (CLASSIFICATION_PIPELINE .getBytes (StandardCharsets .UTF_8 )),
57
- XContentType .JSON ).get ().isAcknowledged (), is (true ));
58
-
59
- client ().prepareIndex ("index_for_inference_test" , MapperService .SINGLE_MAPPING_NAME )
60
- .setSource (new HashMap <String , Object >(){{
61
- put ("col1" , randomFrom ("female" , "male" ));
62
- put ("col2" , randomFrom ("S" , "M" , "L" , "XL" ));
63
- put ("col3" , randomFrom ("true" , "false" , "none" , "other" ));
64
- put ("col4" , randomIntBetween (0 , 10 ));
65
- }})
66
- .setPipeline ("simple_classification_pipeline" )
67
- .get ();
68
-
69
- assertThat (client ().admin ().cluster ().prepareDeletePipeline ("simple_classification_pipeline" ).get ().isAcknowledged (),
70
- is (true ));
71
-
72
- assertThat (client ().admin ().cluster ().preparePutPipeline ("simple_regression_pipeline" ,
73
- new BytesArray (REGRESSION_PIPELINE .getBytes (StandardCharsets .UTF_8 )),
74
- XContentType .JSON ).get ().isAcknowledged (), is (true ));
75
-
76
- client ().prepareIndex ("index_for_inference_test" , MapperService .SINGLE_MAPPING_NAME )
77
- .setSource (new HashMap <String , Object >(){{
78
- put ("col1" , randomFrom ("female" , "male" ));
79
- put ("col2" , randomFrom ("S" , "M" , "L" , "XL" ));
80
- put ("col3" , randomFrom ("true" , "false" , "none" , "other" ));
81
- put ("col4" , randomIntBetween (0 , 10 ));
82
- }})
83
- .setPipeline ("simple_regression_pipeline" )
84
- .get ();
85
-
86
- assertThat (client ().admin ().cluster ().prepareDeletePipeline ("simple_regression_pipeline" ).get ().isAcknowledged (),
87
- is (true ));
88
- }
73
+ client ().performRequest (putPipeline ("simple_classification_pipeline" , CLASSIFICATION_PIPELINE ));
74
+ client ().performRequest (indexRequest ("index_for_inference_test" , "simple_classification_pipeline" , generateSourceDoc ()));
75
+ client ().performRequest (new Request ("DELETE" , "_ingest/pipeline/simple_classification_pipeline" ));
89
76
90
- assertThat (client ().admin ().cluster ().preparePutPipeline ("simple_classification_pipeline" ,
91
- new BytesArray (CLASSIFICATION_PIPELINE .getBytes (StandardCharsets .UTF_8 )),
92
- XContentType .JSON ).get ().isAcknowledged (), is (true ));
77
+ client ().performRequest (putPipeline ("simple_regression_pipeline" , REGRESSION_PIPELINE ));
78
+ client ().performRequest (indexRequest ("index_for_inference_test" , "simple_regression_pipeline" , generateSourceDoc ()));
79
+ client ().performRequest (new Request ("DELETE" , "_ingest/pipeline/simple_regression_pipeline" ));
80
+ }
93
81
94
- assertThat (client ().admin ().cluster ().preparePutPipeline ("simple_regression_pipeline" ,
95
- new BytesArray (REGRESSION_PIPELINE .getBytes (StandardCharsets .UTF_8 )),
96
- XContentType .JSON ).get ().isAcknowledged (), is (true ));
82
+ client ().performRequest (putPipeline ("simple_classification_pipeline" , CLASSIFICATION_PIPELINE ));
83
+ client ().performRequest (putPipeline ("simple_regression_pipeline" , REGRESSION_PIPELINE ));
97
84
98
85
for (int i = 0 ; i < 10 ; i ++) {
99
- client ().prepareIndex ("index_for_inference_test" , MapperService .SINGLE_MAPPING_NAME )
100
- .setSource (generateSourceDoc ())
101
- .setPipeline ("simple_classification_pipeline" )
102
- .get ();
103
-
104
- client ().prepareIndex ("index_for_inference_test" , MapperService .SINGLE_MAPPING_NAME )
105
- .setSource (generateSourceDoc ())
106
- .setPipeline ("simple_regression_pipeline" )
107
- .get ();
86
+ client ().performRequest (indexRequest ("index_for_inference_test" , "simple_classification_pipeline" , generateSourceDoc ()));
87
+ client ().performRequest (indexRequest ("index_for_inference_test" , "simple_regression_pipeline" , generateSourceDoc ()));
108
88
}
109
89
110
- assertThat (client ().admin ().cluster ().prepareDeletePipeline ("simple_classification_pipeline" ).get ().isAcknowledged (),
111
- is (true ));
112
-
113
- assertThat (client ().admin ().cluster ().prepareDeletePipeline ("simple_regression_pipeline" ).get ().isAcknowledged (),
114
- is (true ));
115
-
116
- client ().admin ().indices ().refresh (new RefreshRequest ("index_for_inference_test" )).get ();
117
-
118
- assertThat (client ().search (new SearchRequest ().indices ("index_for_inference_test" )
119
- .source (new SearchSourceBuilder ()
120
- .size (0 )
121
- .trackTotalHits (true )
122
- .query (QueryBuilders .boolQuery ()
123
- .filter (
124
- QueryBuilders .existsQuery ("ml.inference.regression.predicted_value" ))))).get ().getHits ().getTotalHits ().value ,
125
- equalTo (20L ));
126
-
127
- assertThat (client ().search (new SearchRequest ().indices ("index_for_inference_test" )
128
- .source (new SearchSourceBuilder ()
129
- .size (0 )
130
- .trackTotalHits (true )
131
- .query (QueryBuilders .boolQuery ()
132
- .filter (
133
- QueryBuilders .existsQuery ("ml.inference.classification.predicted_value" )))))
134
- .get ()
135
- .getHits ()
136
- .getTotalHits ()
137
- .value ,
138
- equalTo (20L ));
90
+ client ().performRequest (new Request ("DELETE" , "_ingest/pipeline/simple_regression_pipeline" ));
91
+ client ().performRequest (new Request ("DELETE" , "_ingest/pipeline/simple_classification_pipeline" ));
139
92
93
+ client ().performRequest (new Request ("POST" , "index_for_inference_test/_refresh" ));
94
+
95
+
96
+ Response searchResponse = client ().performRequest (searchRequest ("index_for_inference_test" ,
97
+ QueryBuilders .boolQuery ()
98
+ .filter (
99
+ QueryBuilders .existsQuery ("ml.inference.regression.predicted_value" ))));
100
+ assertThat (EntityUtils .toString (searchResponse .getEntity ()), containsString ("\" value\" :20" ));
101
+
102
+ searchResponse = client ().performRequest (searchRequest ("index_for_inference_test" ,
103
+ QueryBuilders .boolQuery ()
104
+ .filter (
105
+ QueryBuilders .existsQuery ("ml.inference.classification.predicted_value" ))));
106
+
107
+ assertThat (EntityUtils .toString (searchResponse .getEntity ()), containsString ("\" value\" :20" ));
140
108
}
141
109
142
- public void testSimulate () {
110
+ public void testSimulate () throws IOException {
143
111
String source = "{\n " +
144
112
" \" pipeline\" : {\n " +
145
113
" \" processors\" : [\n " +
@@ -181,15 +149,10 @@ public void testSimulate() {
181
149
" }}]\n " +
182
150
"}" ;
183
151
184
- SimulatePipelineResponse response = client ().admin ().cluster ()
185
- .prepareSimulatePipeline (new BytesArray (source .getBytes (StandardCharsets .UTF_8 )),
186
- XContentType .JSON ).get ();
187
- SimulateDocumentBaseResult baseResult = (SimulateDocumentBaseResult )response .getResults ().get (0 );
188
- assertThat (baseResult .getIngestDocument ().getFieldValue ("ml.regression.predicted_value" , Double .class ), equalTo (1.0 ));
189
- assertThat (baseResult .getIngestDocument ().getFieldValue ("ml.classification.predicted_value" , String .class ),
190
- equalTo ("second" ));
191
- assertThat (baseResult .getIngestDocument ().getFieldValue ("ml.classification.result_class_prob" , List .class ).size (),
192
- equalTo (2 ));
152
+ Response response = client ().performRequest (simulateRequest (source ));
153
+ String responseString = EntityUtils .toString (response .getEntity ());
154
+ assertThat (responseString , containsString ("\" predicted_value\" :\" second\" " ));
155
+ assertThat (responseString , containsString ("\" predicted_value\" :1.0" ));
193
156
194
157
String sourceWithMissingModel = "{\n " +
195
158
" \" pipeline\" : {\n " +
@@ -217,15 +180,13 @@ public void testSimulate() {
217
180
" }}]\n " +
218
181
"}" ;
219
182
220
- response = client ().admin ().cluster ()
221
- .prepareSimulatePipeline (new BytesArray (sourceWithMissingModel .getBytes (StandardCharsets .UTF_8 )),
222
- XContentType .JSON ).get ();
183
+ response = client ().performRequest (simulateRequest (sourceWithMissingModel ));
184
+ responseString = EntityUtils .toString (response .getEntity ());
223
185
224
- assertThat (((SimulateDocumentBaseResult ) response .getResults ().get (0 )).getFailure ().getMessage (),
225
- containsString ("Could not find trained model [test_classification_missing]" ));
186
+ assertThat (responseString , containsString ("Could not find trained model [test_classification_missing]" ));
226
187
}
227
188
228
- public void testSimulateLangIdent () {
189
+ public void testSimulateLangIdent () throws IOException {
229
190
String source = "{\n " +
230
191
" \" pipeline\" : {\n " +
231
192
" \" processors\" : [\n " +
@@ -244,11 +205,43 @@ public void testSimulateLangIdent() {
244
205
" }}]\n " +
245
206
"}" ;
246
207
247
- SimulatePipelineResponse response = client ().admin ().cluster ()
248
- .prepareSimulatePipeline (new BytesArray (source .getBytes (StandardCharsets .UTF_8 )),
249
- XContentType .JSON ).get ();
250
- SimulateDocumentBaseResult baseResult = (SimulateDocumentBaseResult )response .getResults ().get (0 );
251
- assertThat (baseResult .getIngestDocument ().getFieldValue ("ml.inference.predicted_value" , String .class ), equalTo ("en" ));
208
+ Response response = client ().performRequest (simulateRequest (source ));
209
+ assertThat (EntityUtils .toString (response .getEntity ()), containsString ("\" predicted_value\" :\" en\" " ));
210
+ }
211
+
212
+ private static Request simulateRequest (String jsonEntity ) {
213
+ Request request = new Request ("POST" , "_ingest/pipeline/_simulate" );
214
+ request .setJsonEntity (jsonEntity );
215
+ return request ;
216
+ }
217
+
218
+ private static Request indexRequest (String index , String pipeline , Map <String , Object > doc ) throws IOException {
219
+ try (XContentBuilder xContentBuilder = XContentFactory .jsonBuilder ().map (doc )) {
220
+ return indexRequest (index ,
221
+ pipeline ,
222
+ XContentHelper .convertToJson (BytesReference .bytes (xContentBuilder ), false , XContentType .JSON ));
223
+ }
224
+ }
225
+
226
+ private static Request indexRequest (String index , String pipeline , String doc ) {
227
+ Request request = new Request ("POST" , index + "/_doc?pipeline=" + pipeline );
228
+ request .setJsonEntity (doc );
229
+ return request ;
230
+ }
231
+
232
+ private static Request putPipeline (String pipelineId , String pipelineDefinition ) {
233
+ Request request = new Request ("PUT" , "_ingest/pipeline/" + pipelineId );
234
+ request .setJsonEntity (pipelineDefinition );
235
+ return request ;
236
+ }
237
+
238
+ private static Request searchRequest (String index , QueryBuilder queryBuilder ) throws IOException {
239
+ BytesReference reference = XContentHelper .toXContent (queryBuilder , XContentType .JSON , false );
240
+ String queryJson = XContentHelper .convertToJson (reference , false , XContentType .JSON );
241
+ String json = "{\" query\" : " + queryJson + "}" ;
242
+ Request request = new Request ("GET" , index + "/_search?track_total_hits=true" );
243
+ request .setJsonEntity (json );
244
+ return request ;
252
245
}
253
246
254
247
private Map <String , Object > generateSourceDoc () {
@@ -380,16 +373,9 @@ private Map<String, Object> generateSourceDoc() {
380
373
"}" ;
381
374
382
375
private static final String REGRESSION_CONFIG = "{" +
383
- " \" model_id\" : \" test_regression\" ,\n " +
384
376
" \" input\" :{\" field_names\" :[\" col1\" ,\" col2\" ,\" col3\" ,\" col4\" ]}," +
385
377
" \" description\" : \" test model for regression\" ,\n " +
386
- " \" version\" : \" 7.6.0\" ,\n " +
387
- " \" definition\" : " + REGRESSION_DEFINITION + "," +
388
- " \" license_level\" : \" platinum\" ,\n " +
389
- " \" created_by\" : \" ml_test\" ,\n " +
390
- " \" estimated_heap_memory_usage_bytes\" : 0," +
391
- " \" estimated_operations\" : 0," +
392
- " \" created_time\" : 0" +
378
+ " \" definition\" : " + REGRESSION_DEFINITION +
393
379
"}" ;
394
380
395
381
private static final String CLASSIFICATION_DEFINITION = "{" +
@@ -512,41 +498,16 @@ private Map<String, Object> generateSourceDoc() {
512
498
" }\n " +
513
499
"}" ;
514
500
515
- private TrainedModelConfig buildClassificationModel () throws IOException {
516
- try (XContentParser parser = XContentHelper .createParser (xContentRegistry (),
517
- DeprecationHandler .THROW_UNSUPPORTED_OPERATION ,
518
- new BytesArray (CLASSIFICATION_CONFIG ),
519
- XContentType .JSON )) {
520
- return TrainedModelConfig .LENIENT_PARSER .apply (parser , null ).build ();
521
- }
522
- }
523
-
524
- private TrainedModelConfig buildRegressionModel () throws IOException {
525
- try (XContentParser parser = XContentHelper .createParser (xContentRegistry (),
526
- DeprecationHandler .THROW_UNSUPPORTED_OPERATION ,
527
- new BytesArray (REGRESSION_CONFIG ),
528
- XContentType .JSON )) {
529
- return TrainedModelConfig .LENIENT_PARSER .apply (parser , null ).build ();
530
- }
531
- }
532
-
533
501
@ Override
534
502
protected NamedXContentRegistry xContentRegistry () {
535
503
return new NamedXContentRegistry (new MlInferenceNamedXContentProvider ().getNamedXContentParsers ());
536
504
}
537
505
538
506
private static final String CLASSIFICATION_CONFIG = "" +
539
507
"{\n " +
540
- " \" model_id\" : \" test_classification\" ,\n " +
541
508
" \" input\" :{\" field_names\" :[\" col1\" ,\" col2\" ,\" col3\" ,\" col4\" ]}," +
542
509
" \" description\" : \" test model for classification\" ,\n " +
543
- " \" version\" : \" 7.6.0\" ,\n " +
544
- " \" definition\" : " + CLASSIFICATION_DEFINITION + "," +
545
- " \" license_level\" : \" platinum\" ,\n " +
546
- " \" created_by\" : \" es_test\" ,\n " +
547
- " \" estimated_heap_memory_usage_bytes\" : 0," +
548
- " \" estimated_operations\" : 0," +
549
- " \" created_time\" : 0\n " +
510
+ " \" definition\" : " + CLASSIFICATION_DEFINITION +
550
511
"}" ;
551
512
552
513
private static final String CLASSIFICATION_PIPELINE = "{" +
0 commit comments