Skip to content

Commit 82bae15

Browse files
[FEATURE][ML] Ensure data extractor is not leaking scroll contexts
1 parent a2268a2 commit 82bae15

File tree

2 files changed: +369 -22 lines changed

x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/dataframe/extractor/DataFrameDataExtractor.java

+28-22
Original file line number | Diff line number | Diff line change
@@ -7,6 +7,7 @@
77

88
import org.apache.logging.log4j.LogManager;
99
import org.apache.logging.log4j.Logger;
10+
import org.apache.logging.log4j.message.ParameterizedMessage;
1011
import org.elasticsearch.action.search.ClearScrollAction;
1112
import org.elasticsearch.action.search.ClearScrollRequest;
1213
import org.elasticsearch.action.search.SearchAction;
@@ -20,7 +21,6 @@
2021
import org.elasticsearch.search.SearchHit;
2122
import org.elasticsearch.search.sort.SortOrder;
2223
import org.elasticsearch.xpack.core.ClientHelper;
23-
import org.elasticsearch.xpack.core.ml.datafeed.extractor.ExtractorUtils;
2424
import org.elasticsearch.xpack.ml.datafeed.extractor.fields.ExtractedField;
2525
import org.elasticsearch.xpack.ml.dataframe.DataFrameAnalyticsFields;
2626

@@ -34,6 +34,7 @@
3434
import java.util.Objects;
3535
import java.util.Optional;
3636
import java.util.concurrent.TimeUnit;
37+
import java.util.function.Supplier;
3738
import java.util.stream.Collectors;
3839

3940
/**
@@ -91,9 +92,28 @@ public Optional<List<Row>> next() throws IOException {
9192

9293
protected List<Row> initScroll() throws IOException {
9394
LOGGER.debug("[{}] Initializing scroll", context.jobId);
94-
SearchResponse searchResponse = executeSearchRequest(buildSearchRequest());
95-
LOGGER.debug("[{}] Search response was obtained", context.jobId);
96-
return processSearchResponse(searchResponse);
95+
return tryRequestWithSearchResponse(() -> executeSearchRequest(buildSearchRequest()));
96+
}
97+
98+
private List<Row> tryRequestWithSearchResponse(Supplier<SearchResponse> request) throws IOException {
99+
try {
100+
// We've set allow_partial_search_results to false which means if something
101+
// goes wrong the request will throw.
102+
SearchResponse searchResponse = request.get();
103+
LOGGER.debug("[{}] Search response was obtained", context.jobId);
104+
105+
// Request was successful so we can restore the flag to retry if a future failure occurs
106+
searchHasShardFailure = false;
107+
108+
return processSearchResponse(searchResponse);
109+
} catch (Exception e) {
110+
if (searchHasShardFailure) {
111+
throw e;
112+
}
113+
LOGGER.warn(new ParameterizedMessage("[{}] Search resulted to failure; retrying once", context.jobId), e);
114+
markScrollAsErrored();
115+
return initScroll();
116+
}
97117
}
98118

99119
protected SearchResponse executeSearchRequest(SearchRequestBuilder searchRequestBuilder) {
@@ -103,6 +123,8 @@ protected SearchResponse executeSearchRequest(SearchRequestBuilder searchRequest
103123
private SearchRequestBuilder buildSearchRequest() {
104124
SearchRequestBuilder searchRequestBuilder = new SearchRequestBuilder(client, SearchAction.INSTANCE)
105125
.setScroll(SCROLL_TIMEOUT)
126+
// This ensures the search throws if there are failures and the scroll context gets cleared automatically
127+
.setAllowPartialSearchResults(false)
106128
.addSort(DataFrameAnalyticsFields.ID, SortOrder.ASC)
107129
.setIndices(context.indices)
108130
.setSize(context.scrollSize)
@@ -117,14 +139,6 @@ private SearchRequestBuilder buildSearchRequest() {
117139
}
118140

119141
private List<Row> processSearchResponse(SearchResponse searchResponse) throws IOException {
120-
121-
if (searchResponse.getFailedShards() > 0 && searchHasShardFailure == false) {
122-
LOGGER.debug("[{}] Resetting scroll search after shard failure", context.jobId);
123-
markScrollAsErrored();
124-
return initScroll();
125-
}
126-
127-
ExtractorUtils.checkSearchWasSuccessful(context.jobId, searchResponse);
128142
scrollId = searchResponse.getScrollId();
129143
if (searchResponse.getHits().getHits().length == 0) {
130144
hasNext = false;
@@ -143,7 +157,6 @@ private List<Row> processSearchResponse(SearchResponse searchResponse) throws IO
143157
rows.add(createRow(hit));
144158
}
145159
return rows;
146-
147160
}
148161

149162
private Row createRow(SearchHit hit) {
@@ -163,15 +176,13 @@ private Row createRow(SearchHit hit) {
163176

164177
private List<Row> continueScroll() throws IOException {
165178
LOGGER.debug("[{}] Continuing scroll with id [{}]", context.jobId, scrollId);
166-
SearchResponse searchResponse = executeSearchScrollRequest(scrollId);
167-
LOGGER.debug("[{}] Search response was obtained", context.jobId);
168-
return processSearchResponse(searchResponse);
179+
return tryRequestWithSearchResponse(() -> executeSearchScrollRequest(scrollId));
169180
}
170181

171182
private void markScrollAsErrored() {
172183
// This could be a transient error with the scroll Id.
173184
// Reinitialise the scroll and try again but only once.
174-
resetScroll();
185+
scrollId = null;
175186
searchHasShardFailure = true;
176187
}
177188

@@ -183,11 +194,6 @@ protected SearchResponse executeSearchScrollRequest(String scrollId) {
183194
.get());
184195
}
185196

186-
private void resetScroll() {
187-
clearScroll(scrollId);
188-
scrollId = null;
189-
}
190-
191197
private void clearScroll(String scrollId) {
192198
if (scrollId != null) {
193199
ClearScrollRequest request = new ClearScrollRequest();

0 commit comments

Comments
 (0)