Skip to content

Commit cabff65

Browse files
authored
[ML] Fixing inference stats race condition (#55163) (#55486)
`updateAndGet` could actually call the internal method more than once on contention. If I read the JavaDocs, it says: ```* @param updateFunction a side-effect-free function``` So, it could be getting multiple updates on contention, thus having a race condition where stats are double counted. To fix, I am going to use a `ReadWriteLock`. The `LongAdder` objects allow fast, thread-safe writes in high-contention environments. These can be protected by the `ReadWriteLock::readLock`. When stats are persisted, I need to call reset on all these adders. This is NOT thread safe if additions are taking place concurrently. So, I am going to protect with `ReadWriteLock::writeLock`. This should prevent race conditions while allowing high(ish) throughput in the highly contended paths in inference. I did some simple throughput tests and this change is not significantly slower and is simpler to grok (IMO). closes #54786
1 parent 24d41eb commit cabff65

File tree

8 files changed

+243
-109
lines changed

8 files changed

+243
-109
lines changed

x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/InferenceStats.java

+46-8
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,8 @@
2121
import java.time.Instant;
2222
import java.util.Objects;
2323
import java.util.concurrent.atomic.LongAdder;
24+
import java.util.concurrent.locks.ReadWriteLock;
25+
import java.util.concurrent.locks.ReentrantReadWriteLock;
2426

2527
public class InferenceStats implements ToXContentObject, Writeable {
2628

@@ -204,6 +206,12 @@ public static class Accumulator {
204206
private final LongAdder failureCountAccumulator = new LongAdder();
205207
private final String modelId;
206208
private final String nodeId;
209+
// curious reader
210+
// you may be wondering why the lock is set to be fair.
211+
// When `currentStatsAndReset` is called, we want it guaranteed that it will eventually execute.
212+
// If a ReadWriteLock is unfair, there are no such guarantees.
213+
// A call to `writeLock().lock()` could pause indefinitely.
214+
private final ReadWriteLock readWriteLock = new ReentrantReadWriteLock(true);
207215

208216
public Accumulator(String modelId, String nodeId) {
209217
this.modelId = modelId;
@@ -226,22 +234,52 @@ public Accumulator merge(InferenceStats otherStats) {
226234
}
227235

228236
public Accumulator incMissingFields() {
229-
this.missingFieldsAccumulator.increment();
230-
return this;
237+
readWriteLock.readLock().lock();
238+
try {
239+
this.missingFieldsAccumulator.increment();
240+
return this;
241+
} finally {
242+
readWriteLock.readLock().unlock();
243+
}
231244
}
232245

233246
public Accumulator incInference() {
234-
this.inferenceAccumulator.increment();
235-
return this;
247+
readWriteLock.readLock().lock();
248+
try {
249+
this.inferenceAccumulator.increment();
250+
return this;
251+
} finally {
252+
readWriteLock.readLock().unlock();
253+
}
236254
}
237255

238256
public Accumulator incFailure() {
239-
this.failureCountAccumulator.increment();
240-
return this;
257+
readWriteLock.readLock().lock();
258+
try {
259+
this.failureCountAccumulator.increment();
260+
return this;
261+
} finally {
262+
readWriteLock.readLock().unlock();
263+
}
241264
}
242265

243-
public InferenceStats currentStats() {
244-
return currentStats(Instant.now());
266+
/**
267+
* Thread safe.
268+
*
269+
* Returns the current stats and resets the values of all the counters.
270+
* @return The current stats
271+
*/
272+
public InferenceStats currentStatsAndReset() {
273+
readWriteLock.writeLock().lock();
274+
try {
275+
InferenceStats stats = currentStats(Instant.now());
276+
this.missingFieldsAccumulator.reset();
277+
this.inferenceAccumulator.reset();
278+
this.failureCountAccumulator.reset();
279+
return stats;
280+
} finally {
281+
readWriteLock.writeLock().unlock();
282+
}
245283
}
246284

247285
public InferenceStats currentStats(Instant timeStamp) {

x-pack/plugin/ml/qa/native-multi-node-tests/src/test/java/org/elasticsearch/xpack/ml/integration/InferenceIngestIT.java

+91-53
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,9 @@
2222
import org.elasticsearch.test.ExternalTestCluster;
2323
import org.elasticsearch.test.SecuritySettingsSourceField;
2424
import org.elasticsearch.test.rest.ESRestTestCase;
25+
import org.elasticsearch.xpack.core.ml.MlStatsIndex;
2526
import org.elasticsearch.xpack.core.ml.inference.MlInferenceNamedXContentProvider;
27+
import org.elasticsearch.xpack.core.ml.inference.persistence.InferenceIndexConstants;
2628
import org.elasticsearch.xpack.core.ml.integration.MlRestTestStateCleaner;
2729
import org.junit.After;
2830
import org.junit.Before;
@@ -46,14 +48,16 @@ public class InferenceIngestIT extends ESRestTestCase {
4648
basicAuthHeaderValue("x_pack_rest_user", SecuritySettingsSourceField.TEST_PASSWORD_SECURE_STRING);
4749

4850
@Before
49-
public void createBothModels() throws Exception {
50-
Request request = new Request("PUT", "_ml/inference/test_classification");
51-
request.setJsonEntity(CLASSIFICATION_CONFIG);
52-
client().performRequest(request);
53-
54-
request = new Request("PUT", "_ml/inference/test_regression");
55-
request.setJsonEntity(REGRESSION_CONFIG);
56-
client().performRequest(request);
51+
public void setup() throws Exception {
52+
Request loggingSettings = new Request("PUT", "_cluster/settings");
53+
loggingSettings.setJsonEntity("" +
54+
"{" +
55+
"\"transient\" : {\n" +
56+
" \"logger.org.elasticsearch.xpack.ml.inference\" : \"TRACE\"\n" +
57+
" }" +
58+
"}");
59+
client().performRequest(loggingSettings);
60+
client().performRequest(new Request("GET", "/_cluster/health?wait_for_status=green&timeout=30s"));
5761
}
5862

5963
@Override
@@ -64,19 +68,33 @@ protected Settings restClientSettings() {
6468
@After
6569
public void cleanUpData() throws Exception {
6670
new MlRestTestStateCleaner(logger, adminClient()).clearMlMetadata();
71+
client().performRequest(new Request("DELETE", InferenceIndexConstants.INDEX_PATTERN));
72+
client().performRequest(new Request("DELETE", MlStatsIndex.indexPattern()));
73+
Request loggingSettings = new Request("PUT", "_cluster/settings");
74+
loggingSettings.setJsonEntity("" +
75+
"{" +
76+
"\"transient\" : {\n" +
77+
" \"logger.org.elasticsearch.xpack.ml.inference\" : null\n" +
78+
" }" +
79+
"}");
80+
client().performRequest(loggingSettings);
6781
ESRestTestCase.waitForPendingTasks(adminClient());
68-
client().performRequest(new Request("DELETE", "_ml/inference/test_classification"));
69-
client().performRequest(new Request("DELETE", "_ml/inference/test_regression"));
7082
}
7183

7284
public void testPathologicalPipelineCreationAndDeletion() throws Exception {
85+
String classificationModelId = "test_pathological_classification";
86+
putModel(classificationModelId, CLASSIFICATION_CONFIG);
87+
88+
String regressionModelId = "test_pathological_regression";
89+
putModel(regressionModelId, REGRESSION_CONFIG);
7390

7491
for (int i = 0; i < 10; i++) {
75-
client().performRequest(putPipeline("simple_classification_pipeline", CLASSIFICATION_PIPELINE));
92+
client().performRequest(putPipeline("simple_classification_pipeline",
93+
pipelineDefinition(classificationModelId, "classification")));
7694
client().performRequest(indexRequest("index_for_inference_test", "simple_classification_pipeline", generateSourceDoc()));
7795
client().performRequest(new Request("DELETE", "_ingest/pipeline/simple_classification_pipeline"));
7896

79-
client().performRequest(putPipeline("simple_regression_pipeline", REGRESSION_PIPELINE));
97+
client().performRequest(putPipeline("simple_regression_pipeline", pipelineDefinition(regressionModelId, "regression")));
8098
client().performRequest(indexRequest("index_for_inference_test", "simple_regression_pipeline", generateSourceDoc()));
8199
client().performRequest(new Request("DELETE", "_ingest/pipeline/simple_regression_pipeline"));
82100
}
@@ -94,13 +112,30 @@ public void testPathologicalPipelineCreationAndDeletion() throws Exception {
94112
QueryBuilders.existsQuery("ml.inference.classification.predicted_value"))));
95113

96114
assertThat(EntityUtils.toString(searchResponse.getEntity()), containsString("\"value\":10"));
115+
assertBusy(() -> {
116+
try {
117+
Response statsResponse = client().performRequest(new Request("GET",
118+
"_ml/inference/" + classificationModelId + "/_stats"));
119+
assertThat(EntityUtils.toString(statsResponse.getEntity()), containsString("\"inference_count\":10"));
120+
statsResponse = client().performRequest(new Request("GET", "_ml/inference/" + regressionModelId + "/_stats"));
121+
assertThat(EntityUtils.toString(statsResponse.getEntity()), containsString("\"inference_count\":10"));
122+
} catch (ResponseException ex) {
123+
//this could just mean shard failures.
124+
fail(ex.getMessage());
125+
}
126+
}, 30, TimeUnit.SECONDS);
97127
}
98128

99-
@AwaitsFix(bugUrl = "https://github.com/elastic/elasticsearch/issues/54786")
100129
public void testPipelineIngest() throws Exception {
130+
String classificationModelId = "test_classification";
131+
putModel(classificationModelId, CLASSIFICATION_CONFIG);
101132

102-
client().performRequest(putPipeline("simple_classification_pipeline", CLASSIFICATION_PIPELINE));
103-
client().performRequest(putPipeline("simple_regression_pipeline", REGRESSION_PIPELINE));
133+
String regressionModelId = "test_regression";
134+
putModel(regressionModelId, REGRESSION_CONFIG);
135+
136+
client().performRequest(putPipeline("simple_classification_pipeline",
137+
pipelineDefinition(classificationModelId, "classification")));
138+
client().performRequest(putPipeline("simple_regression_pipeline", pipelineDefinition(regressionModelId, "regression")));
104139

105140
for (int i = 0; i < 10; i++) {
106141
client().performRequest(indexRequest("index_for_inference_test", "simple_classification_pipeline", generateSourceDoc()));
@@ -131,21 +166,30 @@ public void testPipelineIngest() throws Exception {
131166

132167
assertBusy(() -> {
133168
try {
134-
Response statsResponse = client().performRequest(new Request("GET", "_ml/inference/test_classification/_stats"));
169+
Response statsResponse = client().performRequest(new Request("GET",
170+
"_ml/inference/" + classificationModelId + "/_stats"));
135171
assertThat(EntityUtils.toString(statsResponse.getEntity()), containsString("\"inference_count\":10"));
136-
statsResponse = client().performRequest(new Request("GET", "_ml/inference/test_regression/_stats"));
172+
statsResponse = client().performRequest(new Request("GET", "_ml/inference/" + regressionModelId + "/_stats"));
137173
assertThat(EntityUtils.toString(statsResponse.getEntity()), containsString("\"inference_count\":15"));
138174
// can get both
139175
statsResponse = client().performRequest(new Request("GET", "_ml/inference/_stats"));
140-
assertThat(EntityUtils.toString(statsResponse.getEntity()), containsString("\"inference_count\":15"));
141-
assertThat(EntityUtils.toString(statsResponse.getEntity()), containsString("\"inference_count\":10"));
176+
String entityString = EntityUtils.toString(statsResponse.getEntity());
177+
assertThat(entityString, containsString("\"inference_count\":15"));
178+
assertThat(entityString, containsString("\"inference_count\":10"));
142179
} catch (ResponseException ex) {
143180
//this could just mean shard failures.
181+
fail(ex.getMessage());
144182
}
145183
}, 30, TimeUnit.SECONDS);
146184
}
147185

148186
public void testSimulate() throws IOException {
187+
String classificationModelId = "test_classification_simulate";
188+
putModel(classificationModelId, CLASSIFICATION_CONFIG);
189+
190+
String regressionModelId = "test_regression_simulate";
191+
putModel(regressionModelId, REGRESSION_CONFIG);
192+
149193
String source = "{\n" +
150194
" \"pipeline\": {\n" +
151195
" \"processors\": [\n" +
@@ -157,7 +201,7 @@ public void testSimulate() throws IOException {
157201
" \"top_classes_results_field\": \"result_class_prob\"," +
158202
" \"num_top_feature_importance_values\": 2" +
159203
" }},\n" +
160-
" \"model_id\": \"test_classification\",\n" +
204+
" \"model_id\": \"" + classificationModelId + "\",\n" +
161205
" \"field_map\": {\n" +
162206
" \"col1\": \"col1\",\n" +
163207
" \"col2\": \"col2\",\n" +
@@ -169,7 +213,7 @@ public void testSimulate() throws IOException {
169213
" {\n" +
170214
" \"inference\": {\n" +
171215
" \"target_field\": \"ml.regression\",\n" +
172-
" \"model_id\": \"test_regression\",\n" +
216+
" \"model_id\": \"" + regressionModelId + "\",\n" +
173217
" \"inference_config\": {\"regression\":{}},\n" +
174218
" \"field_map\": {\n" +
175219
" \"col1\": \"col1\",\n" +
@@ -232,6 +276,8 @@ public void testSimulate() throws IOException {
232276
}
233277

234278
public void testSimulateWithDefaultMappedField() throws IOException {
279+
String classificationModelId = "test_classification_default_mapped_field";
280+
putModel(classificationModelId, CLASSIFICATION_CONFIG);
235281
String source = "{\n" +
236282
" \"pipeline\": {\n" +
237283
" \"processors\": [\n" +
@@ -243,7 +289,7 @@ public void testSimulateWithDefaultMappedField() throws IOException {
243289
" \"top_classes_results_field\": \"result_class_prob\"," +
244290
" \"num_top_feature_importance_values\": 2" +
245291
" }},\n" +
246-
" \"model_id\": \"test_classification\",\n" +
292+
" \"model_id\": \"" + classificationModelId + "\",\n" +
247293
" \"field_map\": {}\n" +
248294
" }\n" +
249295
" }\n"+
@@ -607,36 +653,28 @@ protected NamedXContentRegistry xContentRegistry() {
607653
" \"definition\": " + CLASSIFICATION_DEFINITION +
608654
"}";
609655

610-
private static final String CLASSIFICATION_PIPELINE = "{" +
611-
" \"processors\": [\n" +
612-
" {\n" +
613-
" \"inference\": {\n" +
614-
" \"model_id\": \"test_classification\",\n" +
615-
" \"tag\": \"classification\",\n" +
616-
" \"inference_config\": {\"classification\": {}},\n" +
617-
" \"field_map\": {\n" +
618-
" \"col1\": \"col1\",\n" +
619-
" \"col2\": \"col2\",\n" +
620-
" \"col3\": \"col3\",\n" +
621-
" \"col4\": \"col4\"\n" +
622-
" }\n" +
623-
" }\n" +
624-
" }]}\n";
625-
626-
private static final String REGRESSION_PIPELINE = "{" +
627-
" \"processors\": [\n" +
628-
" {\n" +
629-
" \"inference\": {\n" +
630-
" \"model_id\": \"test_regression\",\n" +
631-
" \"tag\": \"regression\",\n" +
632-
" \"inference_config\": {\"regression\": {}},\n" +
633-
" \"field_map\": {\n" +
634-
" \"col1\": \"col1\",\n" +
635-
" \"col2\": \"col2\",\n" +
636-
" \"col3\": \"col3\",\n" +
637-
" \"col4\": \"col4\"\n" +
638-
" }\n" +
639-
" }\n" +
640-
" }]}\n";
656+
private static String pipelineDefinition(String modelId, String inferenceConfig) {
657+
return "{" +
658+
" \"processors\": [\n" +
659+
" {\n" +
660+
" \"inference\": {\n" +
661+
" \"model_id\": \"" + modelId + "\",\n" +
662+
" \"tag\": \""+ inferenceConfig + "\",\n" +
663+
" \"inference_config\": {\"" + inferenceConfig + "\": {}},\n" +
664+
" \"field_map\": {\n" +
665+
" \"col1\": \"col1\",\n" +
666+
" \"col2\": \"col2\",\n" +
667+
" \"col3\": \"col3\",\n" +
668+
" \"col4\": \"col4\"\n" +
669+
" }\n" +
670+
" }\n" +
671+
" }]}\n";
672+
}
673+
674+
private void putModel(String modelId, String modelConfiguration) throws IOException {
675+
Request request = new Request("PUT", "_ml/inference/" + modelId);
676+
request.setJsonEntity(modelConfiguration);
677+
client().performRequest(request);
678+
}
641679

642680
}

0 commit comments

Comments
 (0)