Skip to content

Commit 2c97e26

Browse files
authored
[ML][Data Frame] fix progress measurement for continuous transforms (#43838) (#43887)
* [ML][Data Frame] fix progress measurement for continuous transforms * Update DataFrameIndexer.java
1 parent 8755448 commit 2c97e26

File tree

5 files changed

+161
-61
lines changed

5 files changed

+161
-61
lines changed

x-pack/plugin/data-frame/qa/single-node-tests/src/test/java/org/elasticsearch/xpack/dataframe/integration/DataFrameGetAndGetStatsIT.java

Lines changed: 84 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,9 +13,11 @@
1313
import org.junit.Before;
1414

1515
import java.io.IOException;
16+
import java.time.Instant;
1617
import java.util.Collections;
1718
import java.util.List;
1819
import java.util.Map;
20+
import java.util.concurrent.TimeUnit;
1921

2022
import static org.elasticsearch.xpack.core.security.authc.support.UsernamePasswordToken.basicAuthHeaderValue;
2123
import static org.hamcrest.Matchers.equalTo;
@@ -202,4 +204,86 @@ public void testGetProgressStatsWithPivotQuery() throws Exception {
202204
assertThat("percent_complete is not 100.0", progress.get("percent_complete"), equalTo(100.0));
203205
}
204206
}
207+
208+
@SuppressWarnings("unchecked")
209+
public void testGetProgressResetWithContinuous() throws Exception {
210+
String transformId = "pivot_progress_continuous";
211+
String transformDest = transformId + "_idx";
212+
String transformSrc = "reviews_cont_pivot_test";
213+
createReviewsIndex(transformSrc);
214+
final Request createDataframeTransformRequest = createRequestWithAuth("PUT", DATAFRAME_ENDPOINT + transformId, null);
215+
String config = "{ \"dest\": {\"index\":\"" + transformDest + "\"},"
216+
+ " \"source\": {\"index\":\"" + transformSrc + "\"},"
217+
+ " \"sync\": {\"time\":{\"field\": \"timestamp\", \"delay\": \"1s\"}},"
218+
+ " \"pivot\": {"
219+
+ " \"group_by\": {"
220+
+ " \"reviewer\": {"
221+
+ " \"terms\": {"
222+
+ " \"field\": \"user_id\""
223+
+ " } } },"
224+
+ " \"aggregations\": {"
225+
+ " \"avg_rating\": {"
226+
+ " \"avg\": {"
227+
+ " \"field\": \"stars\""
228+
+ " } } } }"
229+
+ "}";
230+
231+
createDataframeTransformRequest.setJsonEntity(config);
232+
233+
Map<String, Object> createDataframeTransformResponse = entityAsMap(client().performRequest(createDataframeTransformRequest));
234+
assertThat(createDataframeTransformResponse.get("acknowledged"), equalTo(Boolean.TRUE));
235+
startAndWaitForContinuousTransform(transformId, transformDest, null);
236+
237+
Request getRequest = createRequestWithAuth("GET", DATAFRAME_ENDPOINT + transformId + "/_stats", null);
238+
Map<String, Object> stats = entityAsMap(client().performRequest(getRequest));
239+
List<Map<String, Object>> transformsStats = (List<Map<String, Object>>)XContentMapValues.extractValue("transforms", stats);
240+
assertEquals(1, transformsStats.size());
241+
// Verify that the transform's progress
242+
for (Map<String, Object> transformStats : transformsStats) {
243+
Map<String, Object> progress = (Map<String, Object>)XContentMapValues.extractValue("state.progress", transformStats);
244+
assertThat("total_docs is not 1000", progress.get("total_docs"), equalTo(1000));
245+
assertThat("docs_remaining is not 0", progress.get("docs_remaining"), equalTo(0));
246+
assertThat("percent_complete is not 100.0", progress.get("percent_complete"), equalTo(100.0));
247+
}
248+
249+
// add more docs to verify total_docs gets updated with continuous
250+
int numDocs = 10;
251+
final StringBuilder bulk = new StringBuilder();
252+
long now = Instant.now().toEpochMilli() - 1_000;
253+
for (int i = 0; i < numDocs; i++) {
254+
bulk.append("{\"index\":{\"_index\":\"" + transformSrc + "\"}}\n")
255+
.append("{\"user_id\":\"")
256+
.append("user_")
257+
// Doing only new users so that there is a deterministic number of docs for progress
258+
.append(randomFrom(42, 47, 113))
259+
.append("\",\"business_id\":\"")
260+
.append("business_")
261+
.append(10)
262+
.append("\",\"stars\":")
263+
.append(5)
264+
.append(",\"timestamp\":")
265+
.append(now)
266+
.append("}\n");
267+
}
268+
bulk.append("\r\n");
269+
final Request bulkRequest = new Request("POST", "/_bulk");
270+
bulkRequest.addParameter("refresh", "true");
271+
bulkRequest.setJsonEntity(bulk.toString());
272+
client().performRequest(bulkRequest);
273+
274+
waitForDataFrameCheckpoint(transformId, 2L);
275+
276+
assertBusy(() -> {
277+
Map<String, Object> statsResponse = entityAsMap(client().performRequest(getRequest));
278+
List<Map<String, Object>> contStats = (List<Map<String, Object>>)XContentMapValues.extractValue("transforms", statsResponse);
279+
assertEquals(1, contStats.size());
280+
// add more docs to verify total_docs is the number of new docs added to the index
281+
for (Map<String, Object> transformStats : contStats) {
282+
Map<String, Object> progress = (Map<String, Object>)XContentMapValues.extractValue("state.progress", transformStats);
283+
assertThat("total_docs is not 10", progress.get("total_docs"), equalTo(numDocs));
284+
assertThat("docs_remaining is not 0", progress.get("docs_remaining"), equalTo(0));
285+
assertThat("percent_complete is not 100.0", progress.get("percent_complete"), equalTo(100.0));
286+
}
287+
}, 60, TimeUnit.SECONDS);
288+
}
205289
}

x-pack/plugin/data-frame/qa/single-node-tests/src/test/java/org/elasticsearch/xpack/dataframe/integration/DataFrameTransformProgressIT.java

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -136,7 +136,9 @@ public void testGetProgress() throws Exception {
136136
null);
137137

138138
final RestHighLevelClient restClient = new TestRestHighLevelClient();
139-
SearchResponse response = restClient.search(TransformProgressGatherer.getSearchRequest(config), RequestOptions.DEFAULT);
139+
SearchResponse response = restClient.search(
140+
TransformProgressGatherer.getSearchRequest(config, config.getSource().getQueryConfig().getQuery()),
141+
RequestOptions.DEFAULT);
140142

141143
DataFrameTransformProgress progress =
142144
TransformProgressGatherer.searchResponseToDataFrameTransformProgressFunction().apply(response);
@@ -157,7 +159,8 @@ public void testGetProgress() throws Exception {
157159
pivotConfig,
158160
null);
159161

160-
response = restClient.search(TransformProgressGatherer.getSearchRequest(config), RequestOptions.DEFAULT);
162+
response = restClient.search(TransformProgressGatherer.getSearchRequest(config, config.getSource().getQueryConfig().getQuery()),
163+
RequestOptions.DEFAULT);
161164
progress = TransformProgressGatherer.searchResponseToDataFrameTransformProgressFunction().apply(response);
162165

163166
assertThat(progress.getTotalDocs(), equalTo(35L));
@@ -175,7 +178,8 @@ public void testGetProgress() throws Exception {
175178
pivotConfig,
176179
null);
177180

178-
response = restClient.search(TransformProgressGatherer.getSearchRequest(config), RequestOptions.DEFAULT);
181+
response = restClient.search(TransformProgressGatherer.getSearchRequest(config, config.getSource().getQueryConfig().getQuery()),
182+
RequestOptions.DEFAULT);
179183
progress = TransformProgressGatherer.searchResponseToDataFrameTransformProgressFunction().apply(response);
180184

181185
assertThat(progress.getTotalDocs(), equalTo(0L));

x-pack/plugin/data-frame/src/main/java/org/elasticsearch/xpack/dataframe/transforms/DataFrameIndexer.java

Lines changed: 33 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -117,29 +117,7 @@ protected void onStart(long now, ActionListener<Void> listener) {
117117
if (pageSize == 0) {
118118
pageSize = pivot.getInitialPageSize();
119119
}
120-
121-
// if run for the 1st time, create checkpoint
122-
if (initialRun()) {
123-
createCheckpoint(ActionListener.wrap(cp -> {
124-
DataFrameTransformCheckpoint oldCheckpoint = inProgressOrLastCheckpoint;
125-
126-
if (oldCheckpoint.isEmpty()) {
127-
// this is the 1st run, accept the new in progress checkpoint and go on
128-
inProgressOrLastCheckpoint = cp;
129-
listener.onResponse(null);
130-
} else {
131-
logger.debug ("Getting changes from {} to {}", oldCheckpoint.getTimeUpperBound(), cp.getTimeUpperBound());
132-
133-
getChangedBuckets(oldCheckpoint, cp, ActionListener.wrap(changedBuckets -> {
134-
inProgressOrLastCheckpoint = cp;
135-
this.changedBuckets = changedBuckets;
136-
listener.onResponse(null);
137-
}, listener::onFailure));
138-
}
139-
}, listener::onFailure));
140-
} else {
141-
listener.onResponse(null);
142-
}
120+
listener.onResponse(null);
143121
} catch (Exception e) {
144122
listener.onFailure(e);
145123
}
@@ -151,8 +129,8 @@ protected boolean initialRun() {
151129

152130
@Override
153131
protected void onFinish(ActionListener<Void> listener) {
154-
// reset the page size, so we do not memorize a low page size forever, the pagesize will be re-calculated on start
155-
pageSize = 0;
132+
// reset the page size, so we do not memorize a low page size forever
133+
pageSize = pivot.getInitialPageSize();
156134
// reset the changed bucket to free memory
157135
changedBuckets = null;
158136
}
@@ -218,13 +196,7 @@ private Stream<IndexRequest> processBucketsToIndexRequests(CompositeAggregation
218196
});
219197
}
220198

221-
@Override
222-
protected SearchRequest buildSearchRequest() {
223-
SearchRequest searchRequest = new SearchRequest(getConfig().getSource().getIndex());
224-
SearchSourceBuilder sourceBuilder = new SearchSourceBuilder();
225-
sourceBuilder.aggregation(pivot.buildAggregation(getPosition(), pageSize));
226-
sourceBuilder.size(0);
227-
199+
protected QueryBuilder buildFilterQuery() {
228200
QueryBuilder pivotQueryBuilder = getConfig().getSource().getQueryConfig().getQuery();
229201

230202
DataFrameTransformConfig config = getConfig();
@@ -233,9 +205,9 @@ protected SearchRequest buildSearchRequest() {
233205
throw new RuntimeException("in progress checkpoint not found");
234206
}
235207

236-
BoolQueryBuilder filteredQuery = new BoolQueryBuilder().
237-
filter(pivotQueryBuilder).
238-
filter(config.getSyncConfig().getRangeQuery(inProgressOrLastCheckpoint));
208+
BoolQueryBuilder filteredQuery = new BoolQueryBuilder()
209+
.filter(pivotQueryBuilder)
210+
.filter(config.getSyncConfig().getRangeQuery(inProgressOrLastCheckpoint));
239211

240212
if (changedBuckets != null && changedBuckets.isEmpty() == false) {
241213
QueryBuilder pivotFilter = pivot.filterBuckets(changedBuckets);
@@ -245,11 +217,19 @@ protected SearchRequest buildSearchRequest() {
245217
}
246218

247219
logger.trace("running filtered query: {}", filteredQuery);
248-
sourceBuilder.query(filteredQuery);
220+
return filteredQuery;
249221
} else {
250-
sourceBuilder.query(pivotQueryBuilder);
222+
return pivotQueryBuilder;
251223
}
224+
}
252225

226+
@Override
227+
protected SearchRequest buildSearchRequest() {
228+
SearchRequest searchRequest = new SearchRequest(getConfig().getSource().getIndex());
229+
SearchSourceBuilder sourceBuilder = new SearchSourceBuilder()
230+
.aggregation(pivot.buildAggregation(getPosition(), pageSize))
231+
.size(0)
232+
.query(buildFilterQuery());
253233
searchRequest.source(sourceBuilder);
254234
return searchRequest;
255235
}
@@ -292,15 +272,24 @@ protected boolean handleCircuitBreakingException(Exception e) {
292272
return true;
293273
}
294274

295-
private void getChangedBuckets(DataFrameTransformCheckpoint oldCheckpoint, DataFrameTransformCheckpoint newCheckpoint,
296-
ActionListener<Map<String, Set<String>>> listener) {
297-
275+
protected void getChangedBuckets(DataFrameTransformCheckpoint oldCheckpoint,
276+
DataFrameTransformCheckpoint newCheckpoint,
277+
ActionListener<Map<String, Set<String>>> listener) {
278+
279+
ActionListener<Map<String, Set<String>>> wrappedListener = ActionListener.wrap(
280+
r -> {
281+
this.inProgressOrLastCheckpoint = newCheckpoint;
282+
this.changedBuckets = r;
283+
listener.onResponse(r);
284+
},
285+
listener::onFailure
286+
);
298287
// initialize the map of changed buckets, the map might be empty if source do not require/implement
299288
// changed bucket detection
300289
Map<String, Set<String>> keys = pivot.initialIncrementalBucketUpdateMap();
301290
if (keys.isEmpty()) {
302291
logger.trace("This data frame does not implement changed bucket detection, returning");
303-
listener.onResponse(null);
292+
wrappedListener.onResponse(null);
304293
return;
305294
}
306295

@@ -324,17 +313,17 @@ private void getChangedBuckets(DataFrameTransformCheckpoint oldCheckpoint, DataF
324313
sourceBuilder.query(filteredQuery);
325314
} else {
326315
logger.trace("No sync configured");
327-
listener.onResponse(null);
316+
wrappedListener.onResponse(null);
328317
return;
329318
}
330319

331320
searchRequest.source(sourceBuilder);
332321
searchRequest.allowPartialSearchResults(false);
333322

334-
collectChangedBuckets(searchRequest, changesAgg, keys, ActionListener.wrap(listener::onResponse, e -> {
323+
collectChangedBuckets(searchRequest, changesAgg, keys, ActionListener.wrap(wrappedListener::onResponse, e -> {
335324
// fall back if bucket collection failed
336325
logger.error("Failed to retrieve changed buckets, fall back to complete retrieval", e);
337-
listener.onResponse(null);
326+
wrappedListener.onResponse(null);
338327
}));
339328
}
340329

x-pack/plugin/data-frame/src/main/java/org/elasticsearch/xpack/dataframe/transforms/DataFrameTransformTask.java

Lines changed: 28 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@
4949

5050
import java.util.Arrays;
5151
import java.util.Map;
52+
import java.util.Set;
5253
import java.util.concurrent.CountDownLatch;
5354
import java.util.concurrent.TimeUnit;
5455
import java.util.concurrent.atomic.AtomicInteger;
@@ -523,17 +524,35 @@ protected void onStart(long now, ActionListener<Void> listener) {
523524
// Since multiple checkpoints can be executed in the task while it is running on the same node, we need to gather
524525
// the progress here, and not in the executor.
525526
if (initialRun()) {
526-
TransformProgressGatherer.getInitialProgress(this.client, getConfig(), ActionListener.wrap(
527-
newProgress -> {
528-
progress = newProgress;
529-
super.onStart(now, listener);
527+
ActionListener<Map<String, Set<String>>> changedBucketsListener = ActionListener.wrap(
528+
r -> {
529+
TransformProgressGatherer.getInitialProgress(this.client, buildFilterQuery(), getConfig(), ActionListener.wrap(
530+
newProgress -> {
531+
logger.trace("[{}] reset the progress from [{}] to [{}]", transformId, progress, newProgress);
532+
progress = newProgress;
533+
super.onStart(now, listener);
534+
},
535+
failure -> {
536+
progress = null;
537+
logger.warn("Unable to load progress information for task [" + transformId + "]", failure);
538+
super.onStart(now, listener);
539+
}
540+
));
530541
},
531-
failure -> {
532-
progress = null;
533-
logger.warn("Unable to load progress information for task [" + transformId + "]", failure);
534-
super.onStart(now, listener);
542+
listener::onFailure
543+
);
544+
545+
createCheckpoint(ActionListener.wrap(cp -> {
546+
DataFrameTransformCheckpoint oldCheckpoint = inProgressOrLastCheckpoint;
547+
if (oldCheckpoint.isEmpty()) {
548+
// this is the 1st run, accept the new in progress checkpoint and go on
549+
inProgressOrLastCheckpoint = cp;
550+
changedBucketsListener.onResponse(null);
551+
} else {
552+
logger.debug ("Getting changes from {} to {}", oldCheckpoint.getTimeUpperBound(), cp.getTimeUpperBound());
553+
getChangedBuckets(oldCheckpoint, cp, changedBucketsListener);
535554
}
536-
));
555+
}, listener::onFailure));
537556
} else {
538557
super.onStart(now, listener);
539558
}

x-pack/plugin/data-frame/src/main/java/org/elasticsearch/xpack/dataframe/transforms/TransformProgressGatherer.java

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
import org.elasticsearch.action.search.SearchResponse;
1313
import org.elasticsearch.client.Client;
1414
import org.elasticsearch.index.query.BoolQueryBuilder;
15+
import org.elasticsearch.index.query.QueryBuilder;
1516
import org.elasticsearch.index.query.QueryBuilders;
1617
import org.elasticsearch.search.builder.SearchSourceBuilder;
1718
import org.elasticsearch.xpack.core.ClientHelper;
@@ -28,13 +29,16 @@ public final class TransformProgressGatherer {
2829
/**
2930
* This gathers the total docs given the config and search
3031
*
31-
* TODO: Support checkpointing logic to restrict the query
32-
* @param progressListener The listener to alert on completion
32+
* @param client ES Client to make queries
33+
* @param filterQuery The adapted filter that can optionally take into account checkpoint information
34+
* @param config The transform config containing headers, source, pivot, etc. information
35+
* @param progressListener The listener to notify when progress object has been created
3336
*/
3437
public static void getInitialProgress(Client client,
38+
QueryBuilder filterQuery,
3539
DataFrameTransformConfig config,
3640
ActionListener<DataFrameTransformProgress> progressListener) {
37-
SearchRequest request = getSearchRequest(config);
41+
SearchRequest request = getSearchRequest(config, filterQuery);
3842

3943
ActionListener<SearchResponse> searchResponseActionListener = ActionListener.wrap(
4044
searchResponse -> progressListener.onResponse(searchResponseToDataFrameTransformProgressFunction().apply(searchResponse)),
@@ -48,7 +52,7 @@ public static void getInitialProgress(Client client,
4852
searchResponseActionListener);
4953
}
5054

51-
public static SearchRequest getSearchRequest(DataFrameTransformConfig config) {
55+
public static SearchRequest getSearchRequest(DataFrameTransformConfig config, QueryBuilder filteredQuery) {
5256
SearchRequest request = new SearchRequest(config.getSource().getIndex());
5357
request.allowPartialSearchResults(false);
5458
BoolQueryBuilder existsClauses = QueryBuilders.boolQuery();
@@ -63,7 +67,7 @@ public static SearchRequest getSearchRequest(DataFrameTransformConfig config) {
6367
.size(0)
6468
.trackTotalHits(true)
6569
.query(QueryBuilders.boolQuery()
66-
.filter(config.getSource().getQueryConfig().getQuery())
70+
.filter(filteredQuery)
6771
.filter(existsClauses)));
6872
return request;
6973
}

0 commit comments

Comments
 (0)