22
22
import org .elasticsearch .test .ExternalTestCluster ;
23
23
import org .elasticsearch .test .SecuritySettingsSourceField ;
24
24
import org .elasticsearch .test .rest .ESRestTestCase ;
25
+ import org .elasticsearch .xpack .core .ml .MlStatsIndex ;
25
26
import org .elasticsearch .xpack .core .ml .inference .MlInferenceNamedXContentProvider ;
27
+ import org .elasticsearch .xpack .core .ml .inference .persistence .InferenceIndexConstants ;
26
28
import org .elasticsearch .xpack .core .ml .integration .MlRestTestStateCleaner ;
27
29
import org .junit .After ;
28
30
import org .junit .Before ;
@@ -46,14 +48,16 @@ public class InferenceIngestIT extends ESRestTestCase {
46
48
basicAuthHeaderValue ("x_pack_rest_user" , SecuritySettingsSourceField .TEST_PASSWORD_SECURE_STRING );
47
49
48
50
@ Before
49
- public void createBothModels () throws Exception {
50
- Request request = new Request ("PUT" , "_ml/inference/test_classification" );
51
- request .setJsonEntity (CLASSIFICATION_CONFIG );
52
- client ().performRequest (request );
53
-
54
- request = new Request ("PUT" , "_ml/inference/test_regression" );
55
- request .setJsonEntity (REGRESSION_CONFIG );
56
- client ().performRequest (request );
51
+ public void setup () throws Exception {
52
+ Request loggingSettings = new Request ("PUT" , "_cluster/settings" );
53
+ loggingSettings .setJsonEntity ("" +
54
+ "{" +
55
+ "\" transient\" : {\n " +
56
+ " \" logger.org.elasticsearch.xpack.ml.inference\" : \" TRACE\" \n " +
57
+ " }" +
58
+ "}" );
59
+ client ().performRequest (loggingSettings );
60
+ client ().performRequest (new Request ("GET" , "/_cluster/health?wait_for_status=green&timeout=30s" ));
57
61
}
58
62
59
63
@ Override
@@ -64,19 +68,33 @@ protected Settings restClientSettings() {
64
68
@ After
65
69
public void cleanUpData () throws Exception {
66
70
new MlRestTestStateCleaner (logger , adminClient ()).clearMlMetadata ();
71
+ client ().performRequest (new Request ("DELETE" , InferenceIndexConstants .INDEX_PATTERN ));
72
+ client ().performRequest (new Request ("DELETE" , MlStatsIndex .indexPattern ()));
73
+ Request loggingSettings = new Request ("PUT" , "_cluster/settings" );
74
+ loggingSettings .setJsonEntity ("" +
75
+ "{" +
76
+ "\" transient\" : {\n " +
77
+ " \" logger.org.elasticsearch.xpack.ml.inference\" : null\n " +
78
+ " }" +
79
+ "}" );
80
+ client ().performRequest (loggingSettings );
67
81
ESRestTestCase .waitForPendingTasks (adminClient ());
68
- client ().performRequest (new Request ("DELETE" , "_ml/inference/test_classification" ));
69
- client ().performRequest (new Request ("DELETE" , "_ml/inference/test_regression" ));
70
82
}
71
83
72
84
public void testPathologicalPipelineCreationAndDeletion () throws Exception {
85
+ String classificationModelId = "test_pathological_classification" ;
86
+ putModel (classificationModelId , CLASSIFICATION_CONFIG );
87
+
88
+ String regressionModelId = "test_pathological_regression" ;
89
+ putModel (regressionModelId , REGRESSION_CONFIG );
73
90
74
91
for (int i = 0 ; i < 10 ; i ++) {
75
- client ().performRequest (putPipeline ("simple_classification_pipeline" , CLASSIFICATION_PIPELINE ));
92
+ client ().performRequest (putPipeline ("simple_classification_pipeline" ,
93
+ pipelineDefinition (classificationModelId , "classification" )));
76
94
client ().performRequest (indexRequest ("index_for_inference_test" , "simple_classification_pipeline" , generateSourceDoc ()));
77
95
client ().performRequest (new Request ("DELETE" , "_ingest/pipeline/simple_classification_pipeline" ));
78
96
79
- client ().performRequest (putPipeline ("simple_regression_pipeline" , REGRESSION_PIPELINE ));
97
+ client ().performRequest (putPipeline ("simple_regression_pipeline" , pipelineDefinition ( regressionModelId , "regression" ) ));
80
98
client ().performRequest (indexRequest ("index_for_inference_test" , "simple_regression_pipeline" , generateSourceDoc ()));
81
99
client ().performRequest (new Request ("DELETE" , "_ingest/pipeline/simple_regression_pipeline" ));
82
100
}
@@ -94,13 +112,30 @@ public void testPathologicalPipelineCreationAndDeletion() throws Exception {
94
112
QueryBuilders .existsQuery ("ml.inference.classification.predicted_value" ))));
95
113
96
114
assertThat (EntityUtils .toString (searchResponse .getEntity ()), containsString ("\" value\" :10" ));
115
+ assertBusy (() -> {
116
+ try {
117
+ Response statsResponse = client ().performRequest (new Request ("GET" ,
118
+ "_ml/inference/" + classificationModelId + "/_stats" ));
119
+ assertThat (EntityUtils .toString (statsResponse .getEntity ()), containsString ("\" inference_count\" :10" ));
120
+ statsResponse = client ().performRequest (new Request ("GET" , "_ml/inference/" + regressionModelId + "/_stats" ));
121
+ assertThat (EntityUtils .toString (statsResponse .getEntity ()), containsString ("\" inference_count\" :10" ));
122
+ } catch (ResponseException ex ) {
123
+ //this could just mean shard failures.
124
+ fail (ex .getMessage ());
125
+ }
126
+ }, 30 , TimeUnit .SECONDS );
97
127
}
98
128
99
- @ AwaitsFix (bugUrl = "https://github.com/elastic/elasticsearch/issues/54786" )
100
129
public void testPipelineIngest () throws Exception {
130
+ String classificationModelId = "test_classification" ;
131
+ putModel (classificationModelId , CLASSIFICATION_CONFIG );
101
132
102
- client ().performRequest (putPipeline ("simple_classification_pipeline" , CLASSIFICATION_PIPELINE ));
103
- client ().performRequest (putPipeline ("simple_regression_pipeline" , REGRESSION_PIPELINE ));
133
+ String regressionModelId = "test_regression" ;
134
+ putModel (regressionModelId , REGRESSION_CONFIG );
135
+
136
+ client ().performRequest (putPipeline ("simple_classification_pipeline" ,
137
+ pipelineDefinition (classificationModelId , "classification" )));
138
+ client ().performRequest (putPipeline ("simple_regression_pipeline" , pipelineDefinition (regressionModelId , "regression" )));
104
139
105
140
for (int i = 0 ; i < 10 ; i ++) {
106
141
client ().performRequest (indexRequest ("index_for_inference_test" , "simple_classification_pipeline" , generateSourceDoc ()));
@@ -131,21 +166,30 @@ public void testPipelineIngest() throws Exception {
131
166
132
167
assertBusy (() -> {
133
168
try {
134
- Response statsResponse = client ().performRequest (new Request ("GET" , "_ml/inference/test_classification/_stats" ));
169
+ Response statsResponse = client ().performRequest (new Request ("GET" ,
170
+ "_ml/inference/" + classificationModelId + "/_stats" ));
135
171
assertThat (EntityUtils .toString (statsResponse .getEntity ()), containsString ("\" inference_count\" :10" ));
136
- statsResponse = client ().performRequest (new Request ("GET" , "_ml/inference/test_regression /_stats" ));
172
+ statsResponse = client ().performRequest (new Request ("GET" , "_ml/inference/" + regressionModelId + " /_stats" ));
137
173
assertThat (EntityUtils .toString (statsResponse .getEntity ()), containsString ("\" inference_count\" :15" ));
138
174
// can get both
139
175
statsResponse = client ().performRequest (new Request ("GET" , "_ml/inference/_stats" ));
140
- assertThat (EntityUtils .toString (statsResponse .getEntity ()), containsString ("\" inference_count\" :15" ));
141
- assertThat (EntityUtils .toString (statsResponse .getEntity ()), containsString ("\" inference_count\" :10" ));
176
+ String entityString = EntityUtils .toString (statsResponse .getEntity ());
177
+ assertThat (entityString , containsString ("\" inference_count\" :15" ));
178
+ assertThat (entityString , containsString ("\" inference_count\" :10" ));
142
179
} catch (ResponseException ex ) {
143
180
//this could just mean shard failures.
181
+ fail (ex .getMessage ());
144
182
}
145
183
}, 30 , TimeUnit .SECONDS );
146
184
}
147
185
148
186
public void testSimulate () throws IOException {
187
+ String classificationModelId = "test_classification_simulate" ;
188
+ putModel (classificationModelId , CLASSIFICATION_CONFIG );
189
+
190
+ String regressionModelId = "test_regression_simulate" ;
191
+ putModel (regressionModelId , REGRESSION_CONFIG );
192
+
149
193
String source = "{\n " +
150
194
" \" pipeline\" : {\n " +
151
195
" \" processors\" : [\n " +
@@ -157,7 +201,7 @@ public void testSimulate() throws IOException {
157
201
" \" top_classes_results_field\" : \" result_class_prob\" ," +
158
202
" \" num_top_feature_importance_values\" : 2" +
159
203
" }},\n " +
160
- " \" model_id\" : \" test_classification \" ,\n " +
204
+ " \" model_id\" : \" " + classificationModelId + " \" ,\n " +
161
205
" \" field_map\" : {\n " +
162
206
" \" col1\" : \" col1\" ,\n " +
163
207
" \" col2\" : \" col2\" ,\n " +
@@ -169,7 +213,7 @@ public void testSimulate() throws IOException {
169
213
" {\n " +
170
214
" \" inference\" : {\n " +
171
215
" \" target_field\" : \" ml.regression\" ,\n " +
172
- " \" model_id\" : \" test_regression \" ,\n " +
216
+ " \" model_id\" : \" " + regressionModelId + " \" ,\n " +
173
217
" \" inference_config\" : {\" regression\" :{}},\n " +
174
218
" \" field_map\" : {\n " +
175
219
" \" col1\" : \" col1\" ,\n " +
@@ -232,6 +276,8 @@ public void testSimulate() throws IOException {
232
276
}
233
277
234
278
public void testSimulateWithDefaultMappedField () throws IOException {
279
+ String classificationModelId = "test_classification_default_mapped_field" ;
280
+ putModel (classificationModelId , CLASSIFICATION_CONFIG );
235
281
String source = "{\n " +
236
282
" \" pipeline\" : {\n " +
237
283
" \" processors\" : [\n " +
@@ -243,7 +289,7 @@ public void testSimulateWithDefaultMappedField() throws IOException {
243
289
" \" top_classes_results_field\" : \" result_class_prob\" ," +
244
290
" \" num_top_feature_importance_values\" : 2" +
245
291
" }},\n " +
246
- " \" model_id\" : \" test_classification \" ,\n " +
292
+ " \" model_id\" : \" " + classificationModelId + " \" ,\n " +
247
293
" \" field_map\" : {}\n " +
248
294
" }\n " +
249
295
" }\n " +
@@ -607,36 +653,28 @@ protected NamedXContentRegistry xContentRegistry() {
607
653
" \" definition\" : " + CLASSIFICATION_DEFINITION +
608
654
"}" ;
609
655
610
- private static final String CLASSIFICATION_PIPELINE = "{" +
611
- " \" processors\" : [\n " +
612
- " {\n " +
613
- " \" inference\" : {\n " +
614
- " \" model_id\" : \" test_classification\" ,\n " +
615
- " \" tag\" : \" classification\" ,\n " +
616
- " \" inference_config\" : {\" classification\" : {}},\n " +
617
- " \" field_map\" : {\n " +
618
- " \" col1\" : \" col1\" ,\n " +
619
- " \" col2\" : \" col2\" ,\n " +
620
- " \" col3\" : \" col3\" ,\n " +
621
- " \" col4\" : \" col4\" \n " +
622
- " }\n " +
623
- " }\n " +
624
- " }]}\n " ;
625
-
626
- private static final String REGRESSION_PIPELINE = "{" +
627
- " \" processors\" : [\n " +
628
- " {\n " +
629
- " \" inference\" : {\n " +
630
- " \" model_id\" : \" test_regression\" ,\n " +
631
- " \" tag\" : \" regression\" ,\n " +
632
- " \" inference_config\" : {\" regression\" : {}},\n " +
633
- " \" field_map\" : {\n " +
634
- " \" col1\" : \" col1\" ,\n " +
635
- " \" col2\" : \" col2\" ,\n " +
636
- " \" col3\" : \" col3\" ,\n " +
637
- " \" col4\" : \" col4\" \n " +
638
- " }\n " +
639
- " }\n " +
640
- " }]}\n " ;
656
+ private static String pipelineDefinition (String modelId , String inferenceConfig ) {
657
+ return "{" +
658
+ " \" processors\" : [\n " +
659
+ " {\n " +
660
+ " \" inference\" : {\n " +
661
+ " \" model_id\" : \" " + modelId + "\" ,\n " +
662
+ " \" tag\" : \" " + inferenceConfig + "\" ,\n " +
663
+ " \" inference_config\" : {\" " + inferenceConfig + "\" : {}},\n " +
664
+ " \" field_map\" : {\n " +
665
+ " \" col1\" : \" col1\" ,\n " +
666
+ " \" col2\" : \" col2\" ,\n " +
667
+ " \" col3\" : \" col3\" ,\n " +
668
+ " \" col4\" : \" col4\" \n " +
669
+ " }\n " +
670
+ " }\n " +
671
+ " }]}\n " ;
672
+ }
673
+
674
+ private void putModel (String modelId , String modelConfiguration ) throws IOException {
675
+ Request request = new Request ("PUT" , "_ml/inference/" + modelId );
676
+ request .setJsonEntity (modelConfiguration );
677
+ client ().performRequest (request );
678
+ }
641
679
642
680
}
0 commit comments