|
10 | 10 | import org.elasticsearch.client.RequestOptions;
|
11 | 11 | import org.elasticsearch.client.Response;
|
12 | 12 | import org.elasticsearch.client.ResponseException;
|
| 13 | +import org.elasticsearch.client.ml.GetTrainedModelsStatsResponse; |
| 14 | +import org.elasticsearch.client.ml.inference.TrainedModelStats; |
13 | 15 | import org.elasticsearch.common.bytes.BytesReference;
|
14 | 16 | import org.elasticsearch.common.settings.Settings;
|
15 | 17 | import org.elasticsearch.common.util.concurrent.ThreadContext;
|
16 | 18 | import org.elasticsearch.common.xcontent.NamedXContentRegistry;
|
17 | 19 | import org.elasticsearch.common.xcontent.XContentBuilder;
|
18 | 20 | import org.elasticsearch.common.xcontent.XContentFactory;
|
19 | 21 | import org.elasticsearch.common.xcontent.XContentHelper;
|
| 22 | +import org.elasticsearch.common.xcontent.XContentParser; |
20 | 23 | import org.elasticsearch.common.xcontent.XContentType;
|
| 24 | +import org.elasticsearch.common.xcontent.json.JsonXContent; |
21 | 25 | import org.elasticsearch.index.query.QueryBuilder;
|
22 | 26 | import org.elasticsearch.index.query.QueryBuilders;
|
23 | 27 | import org.elasticsearch.test.ExternalTestCluster;
|
|
38 | 42 | import java.util.concurrent.TimeUnit;
|
39 | 43 |
|
40 | 44 | import static org.hamcrest.CoreMatchers.containsString;
|
| 45 | +import static org.hamcrest.CoreMatchers.notNullValue; |
| 46 | +import static org.hamcrest.Matchers.equalTo; |
| 47 | +import static org.hamcrest.Matchers.greaterThan; |
| 48 | +import static org.hamcrest.Matchers.hasSize; |
| 49 | +import static org.hamcrest.Matchers.is; |
41 | 50 |
|
42 | 51 | /**
|
43 | 52 | * This is a {@link ESRestTestCase} because the cleanup code in {@link ExternalTestCluster#ensureEstimatedStats()} causes problems
|
@@ -134,15 +143,8 @@ public void testPathologicalPipelineCreationAndDeletion() throws Exception {
|
134 | 143 | assertThat(EntityUtils.toString(searchResponse.getEntity()), containsString("\"value\":10"));
|
135 | 144 | assertBusy(() -> {
|
136 | 145 | try {
|
137 |
| - Response statsResponse = client().performRequest(new Request("GET", |
138 |
| - "_ml/trained_models/" + classificationModelId + "/_stats")); |
139 |
| - String response = EntityUtils.toString(statsResponse.getEntity()); |
140 |
| - assertThat(response, containsString("\"inference_count\":10")); |
141 |
| - assertThat(response, containsString("\"cache_miss_count\":30")); |
142 |
| - statsResponse = client().performRequest(new Request("GET", "_ml/trained_models/" + regressionModelId + "/_stats")); |
143 |
| - response = EntityUtils.toString(statsResponse.getEntity()); |
144 |
| - assertThat(response, containsString("\"inference_count\":10")); |
145 |
| - assertThat(response, containsString("\"cache_miss_count\":30")); |
| 146 | + assertStatsWithCacheMisses(classificationModelId, 10L); |
| 147 | + assertStatsWithCacheMisses(regressionModelId, 10L); |
146 | 148 | } catch (ResponseException ex) {
|
147 | 149 | //this could just mean shard failures.
|
148 | 150 | fail(ex.getMessage());
|
@@ -190,27 +192,28 @@ public void testPipelineIngest() throws Exception {
|
190 | 192 |
|
191 | 193 | assertBusy(() -> {
|
192 | 194 | try {
|
193 |
| - Response statsResponse = client().performRequest(new Request("GET", |
194 |
| - "_ml/trained_models/" + classificationModelId + "/_stats")); |
195 |
| - String response = EntityUtils.toString(statsResponse.getEntity()); |
196 |
| - assertThat(response, containsString("\"inference_count\":10")); |
197 |
| - assertThat(response, containsString("\"cache_miss_count\":3")); |
198 |
| - statsResponse = client().performRequest(new Request("GET", "_ml/trained_models/" + regressionModelId + "/_stats")); |
199 |
| - response = EntityUtils.toString(statsResponse.getEntity()); |
200 |
| - assertThat(response, containsString("\"inference_count\":15")); |
201 |
| - assertThat(response, containsString("\"cache_miss_count\":3")); |
202 |
| - // can get both |
203 |
| - statsResponse = client().performRequest(new Request("GET", "_ml/trained_models/_stats")); |
204 |
| - String entityString = EntityUtils.toString(statsResponse.getEntity()); |
205 |
| - assertThat(entityString, containsString("\"inference_count\":15")); |
206 |
| - assertThat(entityString, containsString("\"inference_count\":10")); |
| 195 | + assertStatsWithCacheMisses(classificationModelId, 10L); |
| 196 | + assertStatsWithCacheMisses(regressionModelId, 15L); |
207 | 197 | } catch (ResponseException ex) {
|
208 | 198 | //this could just mean shard failures.
|
209 | 199 | fail(ex.getMessage());
|
210 | 200 | }
|
211 | 201 | }, 30, TimeUnit.SECONDS);
|
212 | 202 | }
|
213 | 203 |
|
| 204 | + public void assertStatsWithCacheMisses(String modelId, long inferenceCount) throws IOException { |
| 205 | + Response statsResponse = client().performRequest(new Request("GET", |
| 206 | + "_ml/trained_models/" + modelId + "/_stats")); |
| 207 | + try (XContentParser parser = createParser(JsonXContent.jsonXContent, statsResponse.getEntity().getContent())) { |
| 208 | + GetTrainedModelsStatsResponse response = GetTrainedModelsStatsResponse.fromXContent(parser); |
| 209 | + assertThat(response.getTrainedModelStats(), hasSize(1)); |
| 210 | + TrainedModelStats trainedModelStats = response.getTrainedModelStats().get(0); |
| 211 | + assertThat(trainedModelStats.getInferenceStats(), is(notNullValue())); |
| 212 | + assertThat(trainedModelStats.getInferenceStats().getInferenceCount(), equalTo(inferenceCount)); |
| 213 | + assertThat(trainedModelStats.getInferenceStats().getCacheMissCount(), greaterThan(0L)); |
| 214 | + } |
| 215 | + } |
| 216 | + |
214 | 217 | public void testSimulate() throws IOException {
|
215 | 218 | String classificationModelId = "test_classification_simulate";
|
216 | 219 | putModel(classificationModelId, CLASSIFICATION_CONFIG);
|
|
0 commit comments