diff --git a/modules/reindex/src/main/java/org/elasticsearch/index/reindex/AbstractAsyncBulkByScrollAction.java b/modules/reindex/src/main/java/org/elasticsearch/index/reindex/AbstractAsyncBulkByScrollAction.java index c6bfcf13eb113..b227bf2af05eb 100644 --- a/modules/reindex/src/main/java/org/elasticsearch/index/reindex/AbstractAsyncBulkByScrollAction.java +++ b/modules/reindex/src/main/java/org/elasticsearch/index/reindex/AbstractAsyncBulkByScrollAction.java @@ -112,6 +112,7 @@ public abstract class AbstractAsyncBulkByScrollAction, ScrollableHitSource.Hit, RequestWrapper> scriptApplier; + private int lastBatchSize; public AbstractAsyncBulkByScrollAction(BulkByScrollTask task, boolean needsSourceDocumentVersions, boolean needsSourceDocumentSeqNoAndPrimaryTerm, Logger logger, ParentTaskAssigningClient client, @@ -211,7 +212,8 @@ private BulkRequest buildBulk(Iterable docs) } protected ScrollableHitSource buildScrollableResultSource(BackoffPolicy backoffPolicy) { - return new ClientScrollableHitSource(logger, backoffPolicy, threadPool, worker::countSearchRetry, this::finishHim, client, + return new ClientScrollableHitSource(logger, backoffPolicy, threadPool, worker::countSearchRetry, + this::onScrollResponse, this::finishHim, client, mainRequest.getSearchRequest()); } @@ -235,19 +237,26 @@ public void start() { } try { startTime.set(System.nanoTime()); - scrollSource.start(response -> onScrollResponse(timeValueNanos(System.nanoTime()), 0, response)); + scrollSource.start(); } catch (Exception e) { finishHim(e); } } + void onScrollResponse(ScrollableHitSource.AsyncResponse asyncResponse) { + // lastBatchStartTime is essentially unused (see WorkerBulkByScrollTaskState.throttleWaitTime. Leaving it for now, since it seems + // like a bug? + onScrollResponse(new TimeValue(System.nanoTime()), this.lastBatchSize, asyncResponse); + } + /** * Process a scroll response. * @param lastBatchStartTime the time when the last batch started. Used to calculate the throttling delay. * @param lastBatchSize the size of the last batch. Used to calculate the throttling delay. - * @param response the scroll response to process + * @param asyncResponse the response to process from ScrollableHitSource */ - void onScrollResponse(TimeValue lastBatchStartTime, int lastBatchSize, ScrollableHitSource.Response response) { + void onScrollResponse(TimeValue lastBatchStartTime, int lastBatchSize, ScrollableHitSource.AsyncResponse asyncResponse) { + ScrollableHitSource.Response response = asyncResponse.response(); logger.debug("[{}]: got scroll response with [{}] hits", task.getId(), response.getHits().size()); if (task.isCancelled()) { logger.debug("[{}]: finishing early because the task was cancelled", task.getId()); @@ -274,7 +283,7 @@ protected void doRun() throws Exception { * It is important that the batch start time be calculated from here, scroll response to scroll response. That way the time * waiting on the scroll doesn't count against this batch in the throttle. */ - prepareBulkRequest(timeValueNanos(System.nanoTime()), response); + prepareBulkRequest(timeValueNanos(System.nanoTime()), asyncResponse); } @Override @@ -291,7 +300,8 @@ public void onFailure(Exception e) { * delay has been slept. Uses the generic thread pool because reindex is rare enough not to need its own thread pool and because the * thread may be blocked by the user script. */ - void prepareBulkRequest(TimeValue thisBatchStartTime, ScrollableHitSource.Response response) { + void prepareBulkRequest(TimeValue thisBatchStartTime, ScrollableHitSource.AsyncResponse asyncResponse) { + ScrollableHitSource.Response response = asyncResponse.response(); logger.debug("[{}]: preparing bulk request", task.getId()); if (task.isCancelled()) { logger.debug("[{}]: finishing early because the task was cancelled", task.getId()); @@ -316,18 +326,18 @@ void prepareBulkRequest(TimeValue thisBatchStartTime, ScrollableHitSource.Respon /* * If we noop-ed the entire batch then just skip to the next batch or the BulkRequest would fail validation. */ - startNextScroll(thisBatchStartTime, timeValueNanos(System.nanoTime()), 0); + notifyDone(thisBatchStartTime, asyncResponse, 0); return; } request.timeout(mainRequest.getTimeout()); request.waitForActiveShards(mainRequest.getWaitForActiveShards()); - sendBulkRequest(thisBatchStartTime, request); + sendBulkRequest(request, () -> notifyDone(thisBatchStartTime, asyncResponse, request.requests().size())); } /** * Send a bulk request, handling retries. */ - void sendBulkRequest(TimeValue thisBatchStartTime, BulkRequest request) { + void sendBulkRequest(BulkRequest request, Runnable onSuccess) { if (logger.isDebugEnabled()) { logger.debug("[{}]: sending [{}] entry, [{}] bulk request", task.getId(), request.requests().size(), new ByteSizeValue(request.estimatedSizeInBytes())); @@ -340,7 +350,7 @@ void sendBulkRequest(TimeValue thisBatchStartTime, BulkRequest request) { bulkRetry.withBackoff(client::bulk, request, new ActionListener() { @Override public void onResponse(BulkResponse response) { - onBulkResponse(thisBatchStartTime, response); + onBulkResponse(response, onSuccess); } @Override @@ -353,7 +363,7 @@ public void onFailure(Exception e) { /** * Processes bulk responses, accounting for failures. */ - void onBulkResponse(TimeValue thisBatchStartTime, BulkResponse response) { + void onBulkResponse(BulkResponse response, Runnable onSuccess) { try { List failures = new ArrayList<>(); Set destinationIndicesThisBatch = new HashSet<>(); @@ -401,28 +411,20 @@ void onBulkResponse(TimeValue thisBatchStartTime, BulkResponse response) { return; } - startNextScroll(thisBatchStartTime, timeValueNanos(System.nanoTime()), response.getItems().length); + onSuccess.run(); } catch (Exception t) { finishHim(t); } } - /** - * Start the next scroll request. - * - * @param lastBatchSize the number of requests sent in the last batch. This is used to calculate the throttling values which are applied - * when the scroll returns - */ - void startNextScroll(TimeValue lastBatchStartTime, TimeValue now, int lastBatchSize) { + void notifyDone(TimeValue thisBatchStartTime, ScrollableHitSource.AsyncResponse asyncResponse, int batchSize) { if (task.isCancelled()) { logger.debug("[{}]: finishing early because the task was cancelled", task.getId()); finishHim(null); return; } - TimeValue extraKeepAlive = worker.throttleWaitTime(lastBatchStartTime, now, lastBatchSize); - scrollSource.startNextScroll(extraKeepAlive, response -> { - onScrollResponse(lastBatchStartTime, lastBatchSize, response); - }); + this.lastBatchSize = batchSize; + asyncResponse.done(worker.throttleWaitTime(thisBatchStartTime, timeValueNanos(System.nanoTime()), batchSize)); } private void recordFailure(Failure failure, List failures) { diff --git a/modules/reindex/src/main/java/org/elasticsearch/index/reindex/TransportReindexAction.java b/modules/reindex/src/main/java/org/elasticsearch/index/reindex/TransportReindexAction.java index 4928e4fd01f26..9da6b8e17ccf5 100644 --- a/modules/reindex/src/main/java/org/elasticsearch/index/reindex/TransportReindexAction.java +++ b/modules/reindex/src/main/java/org/elasticsearch/index/reindex/TransportReindexAction.java @@ -279,7 +279,8 @@ protected ScrollableHitSource buildScrollableResultSource(BackoffPolicy backoffP RemoteInfo remoteInfo = mainRequest.getRemoteInfo(); createdThreads = synchronizedList(new ArrayList<>()); RestClient restClient = buildRestClient(remoteInfo, mainAction.sslConfig, task.getId(), createdThreads); - return new RemoteScrollableHitSource(logger, backoffPolicy, threadPool, worker::countSearchRetry, this::finishHim, + return new RemoteScrollableHitSource(logger, backoffPolicy, threadPool, worker::countSearchRetry, + this::onScrollResponse, this::finishHim, restClient, remoteInfo.getQuery(), mainRequest.getSearchRequest()); } return super.buildScrollableResultSource(backoffPolicy); diff --git a/modules/reindex/src/main/java/org/elasticsearch/index/reindex/remote/RemoteScrollableHitSource.java b/modules/reindex/src/main/java/org/elasticsearch/index/reindex/remote/RemoteScrollableHitSource.java index d729a4d9c3aa7..ef6a4ff41fe32 100644 --- a/modules/reindex/src/main/java/org/elasticsearch/index/reindex/remote/RemoteScrollableHitSource.java +++ b/modules/reindex/src/main/java/org/elasticsearch/index/reindex/remote/RemoteScrollableHitSource.java @@ -30,7 +30,6 @@ import org.elasticsearch.ElasticsearchStatusException; import org.elasticsearch.Version; import org.elasticsearch.action.bulk.BackoffPolicy; -import org.elasticsearch.index.reindex.ScrollableHitSource; import org.elasticsearch.action.search.SearchRequest; import org.elasticsearch.client.Request; import org.elasticsearch.client.ResponseException; @@ -44,9 +43,10 @@ import org.elasticsearch.common.util.concurrent.ThreadContext; import org.elasticsearch.common.xcontent.LoggingDeprecationHandler; import org.elasticsearch.common.xcontent.NamedXContentRegistry; -import org.elasticsearch.common.xcontent.XContentParser; import org.elasticsearch.common.xcontent.XContentParseException; +import org.elasticsearch.common.xcontent.XContentParser; import org.elasticsearch.common.xcontent.XContentType; +import org.elasticsearch.index.reindex.ScrollableHitSource; import org.elasticsearch.rest.RestStatus; import org.elasticsearch.threadpool.ThreadPool; @@ -68,8 +68,9 @@ public class RemoteScrollableHitSource extends ScrollableHitSource { Version remoteVersion; public RemoteScrollableHitSource(Logger logger, BackoffPolicy backoffPolicy, ThreadPool threadPool, Runnable countSearchRetry, - Consumer fail, RestClient client, BytesReference query, SearchRequest searchRequest) { - super(logger, backoffPolicy, threadPool, countSearchRetry, fail); + Consumer onResponse, Consumer fail, + RestClient client, BytesReference query, SearchRequest searchRequest) { + super(logger, backoffPolicy, threadPool, countSearchRetry, onResponse, fail); this.query = query; this.searchRequest = searchRequest; this.client = client; diff --git a/modules/reindex/src/test/java/org/elasticsearch/index/reindex/AsyncBulkByScrollActionTests.java b/modules/reindex/src/test/java/org/elasticsearch/index/reindex/AsyncBulkByScrollActionTests.java index 3d28ce3bcbc96..938ff47d60485 100644 --- a/modules/reindex/src/test/java/org/elasticsearch/index/reindex/AsyncBulkByScrollActionTests.java +++ b/modules/reindex/src/test/java/org/elasticsearch/index/reindex/AsyncBulkByScrollActionTests.java @@ -57,6 +57,7 @@ import org.elasticsearch.client.FilterClient; import org.elasticsearch.client.ParentTaskAssigningClient; import org.elasticsearch.cluster.node.DiscoveryNode; +import org.elasticsearch.common.CheckedConsumer; import org.elasticsearch.common.bytes.BytesArray; import org.elasticsearch.common.settings.Settings; import org.elasticsearch.common.text.Text; @@ -81,6 +82,7 @@ import org.elasticsearch.threadpool.TestThreadPool; import org.elasticsearch.threadpool.ThreadPool; import org.junit.After; +import org.junit.Assert; import org.junit.Before; import java.util.ArrayList; @@ -91,12 +93,14 @@ import java.util.List; import java.util.Map; import java.util.Set; -import java.util.concurrent.CountDownLatch; +import java.util.concurrent.Delayed; import java.util.concurrent.ExecutionException; import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicInteger; import java.util.concurrent.atomic.AtomicReference; import java.util.function.Consumer; +import java.util.stream.IntStream; import static java.util.Collections.emptyList; import static java.util.Collections.emptyMap; @@ -141,12 +145,12 @@ public void setupForTest() { expectedHeaders.clear(); expectedHeaders.put(randomSimpleString(random()), randomSimpleString(random())); - setupClient(new TestThreadPool(getTestName())); + threadPool = new TestThreadPool(getTestName()); + setupClient(threadPool); firstSearchRequest = new SearchRequest(); testRequest = new DummyAbstractBulkByScrollRequest(firstSearchRequest); listener = new PlainActionFuture<>(); scrollId = null; - threadPool = new TestThreadPool(getClass().getName()); taskManager = new TaskManager(Settings.EMPTY, threadPool, Collections.emptySet()); testTask = (BulkByScrollTask) taskManager.register("don'tcare", "hereeither", testRequest); testTask.setWorker(testRequest.getRequestsPerSecond(), null); @@ -206,11 +210,15 @@ public void testStartRetriesOnRejectionButFailsOnTooManyRejections() throws Exce } public void testStartNextScrollRetriesOnRejectionAndSucceeds() throws Exception { + // this test primarily tests ClientScrollableHitSource but left it to test integration to status client.scrollsToReject = randomIntBetween(0, testRequest.getMaxRetries() - 1); - DummyAsyncBulkByScrollAction action = new DummyActionWithoutBackoff(); - action.setScroll(scrollId()); - TimeValue now = timeValueNanos(System.nanoTime()); - action.startNextScroll(now, now, 0); + // use fail() onResponse handler because mocked search never fires on listener. + ClientScrollableHitSource hitSource = new ClientScrollableHitSource(logger, buildTestBackoffPolicy(), + threadPool, + testTask.getWorkerState()::countSearchRetry, r -> fail(), ExceptionsHelper::reThrowIfNotNull, + new ParentTaskAssigningClient(client, localNode, testTask), testRequest.getSearchRequest()); + hitSource.setScroll(scrollId()); + hitSource.startNextScroll(TimeValue.timeValueSeconds(0)); assertBusy(() -> assertEquals(client.scrollsToReject + 1, client.scrollAttempts.get())); if (listener.isDone()) { Object result = listener.get(); @@ -221,15 +229,23 @@ public void testStartNextScrollRetriesOnRejectionAndSucceeds() throws Exception } public void testStartNextScrollRetriesOnRejectionButFailsOnTooManyRejections() throws Exception { + // this test primarily tests ClientScrollableHitSource but left it to test integration to status client.scrollsToReject = testRequest.getMaxRetries() + randomIntBetween(1, 100); - DummyAsyncBulkByScrollAction action = new DummyActionWithoutBackoff(); - action.setScroll(scrollId()); - TimeValue now = timeValueNanos(System.nanoTime()); - action.startNextScroll(now, now, 0); - assertBusy(() -> assertEquals(testRequest.getMaxRetries() + 1, client.scrollAttempts.get())); - assertBusy(() -> assertTrue(listener.isDone())); - ExecutionException e = expectThrows(ExecutionException.class, () -> listener.get()); - assertThat(ExceptionsHelper.stackTrace(e), containsString(EsRejectedExecutionException.class.getSimpleName())); + assertExactlyOnce( + onFail -> { + Consumer validingOnFail = e -> { + assertNotNull(ExceptionsHelper.unwrap(e, EsRejectedExecutionException.class)); + onFail.run(); + }; + ClientScrollableHitSource hitSource = new ClientScrollableHitSource(logger, buildTestBackoffPolicy(), + threadPool, + testTask.getWorkerState()::countSearchRetry, r -> fail(), validingOnFail, + new ParentTaskAssigningClient(client, localNode, testTask), testRequest.getSearchRequest()); + hitSource.setScroll(scrollId()); + hitSource.startNextScroll(TimeValue.timeValueSeconds(0)); + assertBusy(() -> assertEquals(testRequest.getMaxRetries() + 1, client.scrollAttempts.get())); + } + ); assertNull("There shouldn't be a scroll attempt pending that we didn't reject", client.lastScroll.get()); assertEquals(testRequest.getMaxRetries(), testTask.getStatus().getSearchRetries()); } @@ -261,7 +277,7 @@ public void testScrollResponseBatchingBehavior() throws Exception { } } - public void testBulkResponseSetsLotsOfStatus() { + public void testBulkResponseSetsLotsOfStatus() throws Exception { testRequest.setAbortOnVersionConflict(false); int maxBatches = randomIntBetween(0, 100); long versionConflicts = 0; @@ -306,7 +322,10 @@ public void testBulkResponseSetsLotsOfStatus() { new IndexResponse(shardId, "type", "id" + i, seqNo, primaryTerm, randomInt(), createdResponse); responses[i] = new BulkItemResponse(i, opType, response); } - new DummyAsyncBulkByScrollAction().onBulkResponse(timeValueNanos(System.nanoTime()), new BulkResponse(responses, 0)); + assertExactlyOnce(onSuccess -> + new DummyAsyncBulkByScrollAction().onBulkResponse(new BulkResponse(responses, 0), + onSuccess) + ); assertEquals(versionConflicts, testTask.getStatus().getVersionConflicts()); assertEquals(updated, testTask.getStatus().getUpdated()); assertEquals(created, testTask.getStatus().getCreated()); @@ -385,7 +404,7 @@ public void testBulkFailuresAbortRequest() throws Exception { DummyAsyncBulkByScrollAction action = new DummyAsyncBulkByScrollAction(); BulkResponse bulkResponse = new BulkResponse(new BulkItemResponse[] {new BulkItemResponse(0, DocWriteRequest.OpType.CREATE, failure)}, randomLong()); - action.onBulkResponse(timeValueNanos(System.nanoTime()), bulkResponse); + action.onBulkResponse(bulkResponse, Assert::fail); BulkByScrollResponse response = listener.get(); assertThat(response.getBulkFailures(), contains(failure)); assertThat(response.getSearchFailures(), empty()); @@ -444,11 +463,38 @@ public void testScrollDelay() throws Exception { public ScheduledCancellable schedule(Runnable command, TimeValue delay, String name) { capturedDelay.set(delay); capturedCommand.set(command); - return null; + return new ScheduledCancellable() { + private boolean cancelled = false; + @Override + public long getDelay(TimeUnit unit) { + return unit.convert(delay.millis(), TimeUnit.MILLISECONDS); + } + + @Override + public int compareTo(Delayed o) { + return 0; + } + + @Override + public boolean cancel() { + cancelled = true; + return true; + } + + @Override + public boolean isCancelled() { + return cancelled; + } + }; } }); - DummyAsyncBulkByScrollAction action = new DummyAsyncBulkByScrollAction(); + DummyAsyncBulkByScrollAction action = new DummyAsyncBulkByScrollAction() { + @Override + protected RequestWrapper buildRequest(Hit doc) { + return wrap(new IndexRequest().index("test")); + } + }; action.setScroll(scrollId()); // Set the base for the scroll to wait - this is added to the figure we calculate below @@ -456,21 +502,25 @@ public ScheduledCancellable schedule(Runnable command, TimeValue delay, String n // Set throttle to 1 request per second to make the math simpler worker.rethrottle(1f); - // Make the last batch look nearly instant but have 100 documents - TimeValue lastBatchStartTime = timeValueNanos(System.nanoTime()); - TimeValue now = timeValueNanos(lastBatchStartTime.nanos() + 1); - action.startNextScroll(lastBatchStartTime, now, 100); + action.start(); + + // create a simulated response. + SearchHit hit = new SearchHit(0, "id", new Text("type"), emptyMap()).sourceRef(new BytesArray("{}")); + SearchHits hits = new SearchHits(IntStream.range(0, 100).mapToObj(i -> hit).toArray(SearchHit[]::new), + new TotalHits(0, TotalHits.Relation.EQUAL_TO),0); + InternalSearchResponse internalResponse = new InternalSearchResponse(hits, null, null, null, false, false, 1); + SearchResponse searchResponse = new SearchResponse(internalResponse, scrollId(), 5, 4, 0, randomLong(), null, + SearchResponse.Clusters.EMPTY); + + client.lastSearch.get().listener.onResponse(searchResponse); + + assertEquals(0, capturedDelay.get().seconds()); + capturedCommand.get().run(); // So the next request is going to have to wait an extra 100 seconds or so (base was 10 seconds, so 110ish) assertThat(client.lastScroll.get().request.scroll().keepAlive().seconds(), either(equalTo(110L)).or(equalTo(109L))); // Now we can simulate a response and check the delay that we used for the task - SearchHit hit = new SearchHit(0, "id", new Text("type"), emptyMap()); - SearchHits hits = new SearchHits(new SearchHit[] { hit }, new TotalHits(0, TotalHits.Relation.EQUAL_TO), 0); - InternalSearchResponse internalResponse = new InternalSearchResponse(hits, null, null, null, false, false, 1); - SearchResponse searchResponse = new SearchResponse(internalResponse, scrollId(), 5, 4, 0, randomLong(), null, - SearchResponse.Clusters.EMPTY); - if (randomBoolean()) { client.lastScroll.get().listener.onResponse(searchResponse); assertEquals(99, capturedDelay.get().seconds()); @@ -497,30 +547,22 @@ private void bulkRetryTestCase(boolean failWithRejection) throws Exception { testRequest.setMaxRetries(totalFailures - (failWithRejection ? 1 : 0)); client.bulksToReject = client.bulksAttempts.get() + totalFailures; - /* - * When we get a successful bulk response we usually start the next scroll request but lets just intercept that so we don't have to - * deal with it. We just wait for it to happen. - */ - CountDownLatch successLatch = new CountDownLatch(1); - DummyAsyncBulkByScrollAction action = new DummyActionWithoutBackoff() { - @Override - void startNextScroll(TimeValue lastBatchStartTime, TimeValue now, int lastBatchSize) { - successLatch.countDown(); - } - }; + DummyAsyncBulkByScrollAction action = new DummyActionWithoutBackoff(); BulkRequest request = new BulkRequest(); for (int i = 0; i < size + 1; i++) { request.add(new IndexRequest("index", "type", "id" + i)); } - action.sendBulkRequest(timeValueNanos(System.nanoTime()), request); if (failWithRejection) { + action.sendBulkRequest(request, Assert::fail); BulkByScrollResponse response = listener.get(); assertThat(response.getBulkFailures(), hasSize(1)); assertEquals(response.getBulkFailures().get(0).getStatus(), RestStatus.TOO_MANY_REQUESTS); assertThat(response.getSearchFailures(), empty()); assertNull(response.getReasonCancelled()); } else { - assertTrue(successLatch.await(10, TimeUnit.SECONDS)); + assertExactlyOnce(onSuccess -> + action.sendBulkRequest(request, onSuccess) + ); } } @@ -584,17 +626,17 @@ public void testCancelBeforeScrollResponse() throws Exception { public void testCancelBeforeSendBulkRequest() throws Exception { cancelTaskCase((DummyAsyncBulkByScrollAction action) -> - action.sendBulkRequest(timeValueNanos(System.nanoTime()), new BulkRequest())); + action.sendBulkRequest(new BulkRequest(), Assert::fail)); } public void testCancelBeforeOnBulkResponse() throws Exception { cancelTaskCase((DummyAsyncBulkByScrollAction action) -> - action.onBulkResponse(timeValueNanos(System.nanoTime()), new BulkResponse(new BulkItemResponse[0], 0))); + action.onBulkResponse(new BulkResponse(new BulkItemResponse[0], 0), Assert::fail)); } public void testCancelBeforeStartNextScroll() throws Exception { TimeValue now = timeValueNanos(System.nanoTime()); - cancelTaskCase((DummyAsyncBulkByScrollAction action) -> action.startNextScroll(now, now, 0)); + cancelTaskCase((DummyAsyncBulkByScrollAction action) -> action.notifyDone(now, null, 0)); } public void testCancelBeforeRefreshAndFinish() throws Exception { @@ -674,7 +716,17 @@ private void cancelTaskCase(Consumer testMe) throw private void simulateScrollResponse(DummyAsyncBulkByScrollAction action, TimeValue lastBatchTime, int lastBatchSize, ScrollableHitSource.Response response) { action.setScroll(scrollId()); - action.onScrollResponse(lastBatchTime, lastBatchSize, response); + action.onScrollResponse(lastBatchTime, lastBatchSize, new ScrollableHitSource.AsyncResponse() { + @Override + public ScrollableHitSource.Response response() { + return response; + } + + @Override + public void done(TimeValue extraKeepAlive) { + fail(); + } + }); } private class DummyAsyncBulkByScrollAction @@ -696,11 +748,15 @@ protected AbstractAsyncBulkByScrollAction.RequestWrapper buildRequest(Hit doc private class DummyActionWithoutBackoff extends DummyAsyncBulkByScrollAction { @Override BackoffPolicy buildBackoffPolicy() { - // Force a backoff time of 0 to prevent sleeping - return constantBackoff(timeValueMillis(0), testRequest.getMaxRetries()); + return buildTestBackoffPolicy(); } } + private BackoffPolicy buildTestBackoffPolicy() { + // Force a backoff time of 0 to prevent sleeping + return constantBackoff(timeValueMillis(0), testRequest.getMaxRetries()); + } + private static class DummyTransportAsyncBulkByScrollAction extends TransportAction { @@ -887,4 +943,13 @@ private static class RequestAndListener this.listener = listener; } } + + /** + * Assert that calling the consumer invokes the runnable exactly once. + */ + private void assertExactlyOnce(CheckedConsumer consumer) throws Exception { + AtomicBoolean called = new AtomicBoolean(); + consumer.accept(() -> assertTrue(called.compareAndSet(false, true))); + assertBusy(() -> assertTrue(called.get())); + } } diff --git a/modules/reindex/src/test/java/org/elasticsearch/index/reindex/ClientScrollableHitSourceTests.java b/modules/reindex/src/test/java/org/elasticsearch/index/reindex/ClientScrollableHitSourceTests.java new file mode 100644 index 0000000000000..37425a7c600ef --- /dev/null +++ b/modules/reindex/src/test/java/org/elasticsearch/index/reindex/ClientScrollableHitSourceTests.java @@ -0,0 +1,270 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.elasticsearch.index.reindex; + +import org.apache.lucene.search.TotalHits; +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.action.ActionRequest; +import org.elasticsearch.action.ActionResponse; +import org.elasticsearch.action.ActionType; +import org.elasticsearch.action.bulk.BackoffPolicy; +import org.elasticsearch.action.search.SearchAction; +import org.elasticsearch.action.search.SearchRequest; +import org.elasticsearch.action.search.SearchResponse; +import org.elasticsearch.action.search.SearchScrollAction; +import org.elasticsearch.action.search.SearchScrollRequest; +import org.elasticsearch.client.ParentTaskAssigningClient; +import org.elasticsearch.client.support.AbstractClient; +import org.elasticsearch.common.bytes.BytesArray; +import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.common.text.Text; +import org.elasticsearch.common.unit.TimeValue; +import org.elasticsearch.common.util.concurrent.EsRejectedExecutionException; +import org.elasticsearch.search.SearchHit; +import org.elasticsearch.search.SearchHits; +import org.elasticsearch.search.internal.InternalSearchResponse; +import org.elasticsearch.tasks.TaskId; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.threadpool.TestThreadPool; +import org.elasticsearch.threadpool.ThreadPool; +import org.junit.After; +import org.junit.Before; + +import java.util.List; +import java.util.concurrent.ArrayBlockingQueue; +import java.util.concurrent.BlockingQueue; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.function.Consumer; +import java.util.function.Function; +import java.util.stream.IntStream; + +import static java.util.Collections.emptyMap; +import static org.apache.lucene.util.TestUtil.randomSimpleString; +import static org.elasticsearch.common.unit.TimeValue.timeValueSeconds; +import static org.hamcrest.Matchers.instanceOf; + +public class ClientScrollableHitSourceTests extends ESTestCase { + + private ThreadPool threadPool; + + @Before + public void setUpThreadPool() { + threadPool = new TestThreadPool(getTestName()); + } + + @After + public void tearDownThreadPool() { + terminate(threadPool); + } + + // ensure we test the happy path on every build. + public void testStartScrollDone() throws InterruptedException { + dotestBasicsWithRetry(0, 0, 0, e -> fail()); + } + + public void testRetrySuccess() throws InterruptedException { + int retries = randomIntBetween(1, 10); + dotestBasicsWithRetry(retries, 0, retries, e -> fail()); + } + + private static class ExpectedException extends RuntimeException { + ExpectedException(Throwable cause) { + super(cause); + } + } + + public void testRetryFail() { + int retries = randomInt(10); + ExpectedException ex = expectThrows(ExpectedException.class, () -> { + dotestBasicsWithRetry(retries, retries+1, retries+1, e -> { throw new ExpectedException(e); }); + }); + assertThat(ex.getCause(), instanceOf(EsRejectedExecutionException.class)); + } + + private void dotestBasicsWithRetry(int retries, int minFailures, int maxFailures, + Consumer failureHandler) throws InterruptedException { + BlockingQueue responses = new ArrayBlockingQueue<>(100); + MockClient client = new MockClient(threadPool); + TaskId parentTask = new TaskId("thenode", randomInt()); + AtomicInteger actualSearchRetries = new AtomicInteger(); + int expectedSearchRetries = 0; + ClientScrollableHitSource hitSource = new ClientScrollableHitSource(logger, BackoffPolicy.constantBackoff(TimeValue.ZERO, retries), + threadPool, actualSearchRetries::incrementAndGet, responses::add, failureHandler, + new ParentTaskAssigningClient(client, parentTask), + new SearchRequest().scroll("1m")); + + hitSource.start(); + for (int retry = 0; retry < randomIntBetween(minFailures, maxFailures); ++retry) { + client.fail(SearchAction.INSTANCE, new EsRejectedExecutionException()); + client.awaitOperation(); + ++expectedSearchRetries; + } + + SearchResponse searchResponse = createSearchResponse(); + client.respond(SearchAction.INSTANCE, searchResponse); + + for (int i = 0; i < randomIntBetween(1, 10); ++i) { + ScrollableHitSource.AsyncResponse asyncResponse = responses.poll(10, TimeUnit.SECONDS); + assertNotNull(asyncResponse); + assertEquals(responses.size(), 0); + assertSameHits(asyncResponse.response().getHits(), searchResponse.getHits().getHits()); + asyncResponse.done(TimeValue.ZERO); + + for (int retry = 0; retry < randomIntBetween(minFailures, maxFailures); ++retry) { + client.fail(SearchScrollAction.INSTANCE, new EsRejectedExecutionException()); + client.awaitOperation(); + ++expectedSearchRetries; + } + + searchResponse = createSearchResponse(); + client.respond(SearchScrollAction.INSTANCE, searchResponse); + } + + assertEquals(actualSearchRetries.get(), expectedSearchRetries); + } + + public void testScrollKeepAlive() { + MockClient client = new MockClient(threadPool); + TaskId parentTask = new TaskId("thenode", randomInt()); + + ClientScrollableHitSource hitSource = new ClientScrollableHitSource(logger, BackoffPolicy.constantBackoff(TimeValue.ZERO, 0), + threadPool, () -> fail(), r -> fail(), e -> fail(), new ParentTaskAssigningClient(client, + parentTask), + // Set the base for the scroll to wait - this is added to the figure we calculate below + new SearchRequest().scroll(timeValueSeconds(10))); + + hitSource.startNextScroll(timeValueSeconds(100)); + client.validateRequest(SearchScrollAction.INSTANCE, + (SearchScrollRequest r) -> assertEquals(r.scroll().keepAlive().seconds(), 110)); + } + + + + private SearchResponse createSearchResponse() { + // create a simulated response. + SearchHit hit = new SearchHit(0, "id", new Text("type"), emptyMap()).sourceRef(new BytesArray("{}")); + SearchHits hits = new SearchHits(IntStream.range(0, randomIntBetween(0, 20)).mapToObj(i -> hit).toArray(SearchHit[]::new), + new TotalHits(0, TotalHits.Relation.EQUAL_TO),0); + InternalSearchResponse internalResponse = new InternalSearchResponse(hits, null, null, null, false, false, 1); + return new SearchResponse(internalResponse, randomSimpleString(random(), 1, 10), 5, 4, 0, randomLong(), null, + SearchResponse.Clusters.EMPTY); + } + + private void assertSameHits(List actual, SearchHit[] expected) { + assertEquals(actual.size(), expected.length); + for (int i = 0; i < actual.size(); ++i) { + assertEquals(actual.get(i).getSource(), expected[i].getSourceRef()); + assertEquals(actual.get(i).getIndex(), expected[i].getIndex()); + assertEquals(actual.get(i).getVersion(), expected[i].getVersion()); + assertEquals(actual.get(i).getPrimaryTerm(), expected[i].getPrimaryTerm()); + assertEquals(actual.get(i).getSeqNo(), expected[i].getSeqNo()); + assertEquals(actual.get(i).getId(), expected[i].getId()); + assertEquals(actual.get(i).getIndex(), expected[i].getIndex()); + } + } + + private static class ExecuteRequest { + private final ActionType action; + private final Request request; + private final ActionListener listener; + + ExecuteRequest(ActionType action, Request request, ActionListener listener) { + this.action = action; + this.request = request; + this.listener = listener; + } + + public void respond(ActionType action, Function response) { + assertEquals(action, this.action); + listener.onResponse(response.apply(request)); + } + + public void fail(ActionType action, Exception response) { + assertEquals(action, this.action); + listener.onFailure(response); + } + + public void validateRequest(ActionType action, Consumer validator) { + assertEquals(action, this.action); + validator.accept(request); + } + } + + private static class MockClient extends AbstractClient { + private ExecuteRequest executeRequest; + + MockClient(ThreadPool threadPool) { + super(Settings.EMPTY, threadPool); + } + + @Override + protected synchronized + void doExecute(ActionType action, + Request request, ActionListener listener) { + + this.executeRequest = new ExecuteRequest<>(action, request, listener); + this.notifyAll(); + } + + @SuppressWarnings("unchecked") + public void respondx(ActionType action, + Function response) { + ExecuteRequest executeRequest; + synchronized (this) { + executeRequest = this.executeRequest; + this.executeRequest = null; + } + ((ExecuteRequest) executeRequest).respond(action, response); + } + + public void respond(ActionType action, + Response response) { + respondx(action, req -> response); + } + + @SuppressWarnings("unchecked") + public void fail(ActionType action, Exception response) { + ExecuteRequest executeRequest; + synchronized (this) { + executeRequest = this.executeRequest; + this.executeRequest = null; + } + ((ExecuteRequest) executeRequest).fail(action, response); + } + + @SuppressWarnings("unchecked") + public void validateRequest(ActionType action, + Consumer validator) { + ((ExecuteRequest) executeRequest).validateRequest(action, validator); + } + + @Override + public void close() { + } + + public synchronized void awaitOperation() throws InterruptedException { + if (executeRequest == null) { + wait(10000); + assertNotNull("Must receive next request within 10s", executeRequest); + } + } + } +} diff --git a/modules/reindex/src/test/java/org/elasticsearch/index/reindex/remote/RemoteScrollableHitSourceTests.java b/modules/reindex/src/test/java/org/elasticsearch/index/reindex/remote/RemoteScrollableHitSourceTests.java index 0ab100a856fc1..68cc0effc1d91 100644 --- a/modules/reindex/src/test/java/org/elasticsearch/index/reindex/remote/RemoteScrollableHitSourceTests.java +++ b/modules/reindex/src/test/java/org/elasticsearch/index/reindex/remote/RemoteScrollableHitSourceTests.java @@ -572,7 +572,7 @@ private void failRequest(Throwable t) { private class TestRemoteScrollableHitSource extends RemoteScrollableHitSource { TestRemoteScrollableHitSource(RestClient client) { super(RemoteScrollableHitSourceTests.this.logger, backoff(), RemoteScrollableHitSourceTests.this.threadPool, - RemoteScrollableHitSourceTests.this::countRetry, RemoteScrollableHitSourceTests.this::failRequest, client, + RemoteScrollableHitSourceTests.this::countRetry, r -> fail(), RemoteScrollableHitSourceTests.this::failRequest, client, new BytesArray("{}"), RemoteScrollableHitSourceTests.this.searchRequest); } } diff --git a/server/src/main/java/org/elasticsearch/index/reindex/ClientScrollableHitSource.java b/server/src/main/java/org/elasticsearch/index/reindex/ClientScrollableHitSource.java index bbc12fb2c2aea..22ee7835af1e3 100644 --- a/server/src/main/java/org/elasticsearch/index/reindex/ClientScrollableHitSource.java +++ b/server/src/main/java/org/elasticsearch/index/reindex/ClientScrollableHitSource.java @@ -61,8 +61,9 @@ public class ClientScrollableHitSource extends ScrollableHitSource { private final SearchRequest firstSearchRequest; public ClientScrollableHitSource(Logger logger, BackoffPolicy backoffPolicy, ThreadPool threadPool, Runnable countSearchRetry, - Consumer fail, ParentTaskAssigningClient client, SearchRequest firstSearchRequest) { - super(logger, backoffPolicy, threadPool, countSearchRetry, fail); + Consumer onResponse, Consumer fail, + ParentTaskAssigningClient client, SearchRequest firstSearchRequest) { + super(logger, backoffPolicy, threadPool, countSearchRetry, onResponse, fail); this.client = client; this.firstSearchRequest = firstSearchRequest; } diff --git a/server/src/main/java/org/elasticsearch/index/reindex/ScrollableHitSource.java b/server/src/main/java/org/elasticsearch/index/reindex/ScrollableHitSource.java index dc8d69ff4f0d1..2620b4d524da1 100644 --- a/server/src/main/java/org/elasticsearch/index/reindex/ScrollableHitSource.java +++ b/server/src/main/java/org/elasticsearch/index/reindex/ScrollableHitSource.java @@ -39,13 +39,16 @@ import java.io.IOException; import java.util.List; +import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicReference; import java.util.function.Consumer; import static java.util.Objects.requireNonNull; /** - * A scrollable source of results. + * A scrollable source of results. Pumps data out into the passed onResponse consumer. Same data may come out several times in case + * of failures during searching (though not yet). Once the onResponse consumer is done, it should call AsyncResponse.isDone(time) to receive + * more data (only receives one response at a time). */ public abstract class ScrollableHitSource { private final AtomicReference scrollId = new AtomicReference<>(); @@ -54,34 +57,53 @@ public abstract class ScrollableHitSource { protected final BackoffPolicy backoffPolicy; protected final ThreadPool threadPool; protected final Runnable countSearchRetry; + private final Consumer onResponse; protected final Consumer fail; public ScrollableHitSource(Logger logger, BackoffPolicy backoffPolicy, ThreadPool threadPool, Runnable countSearchRetry, - Consumer fail) { + Consumer onResponse, Consumer fail) { this.logger = logger; this.backoffPolicy = backoffPolicy; this.threadPool = threadPool; this.countSearchRetry = countSearchRetry; + this.onResponse = onResponse; this.fail = fail; } - public final void start(Consumer onResponse) { + public final void start() { doStart(response -> { setScroll(response.getScrollId()); logger.debug("scroll returned [{}] documents with a scroll id of [{}]", response.getHits().size(), response.getScrollId()); - onResponse.accept(response); + onResponse(response); }); } protected abstract void doStart(Consumer onResponse); - public final void startNextScroll(TimeValue extraKeepAlive, Consumer onResponse) { + final void startNextScroll(TimeValue extraKeepAlive) { doStartNextScroll(scrollId.get(), extraKeepAlive, response -> { setScroll(response.getScrollId()); - onResponse.accept(response); + onResponse(response); }); } protected abstract void doStartNextScroll(String scrollId, TimeValue extraKeepAlive, Consumer onResponse); + private void onResponse(Response response) { + setScroll(response.getScrollId()); + onResponse.accept(new AsyncResponse() { + private AtomicBoolean alreadyDone = new AtomicBoolean(); + @Override + public Response response() { + return response; + } + + @Override + public void done(TimeValue extraKeepAlive) { + assert alreadyDone.compareAndSet(false, true); + startNextScroll(extraKeepAlive); + } + }); + } + public final void close(Runnable onCompletion) { String scrollId = this.scrollId.get(); if (Strings.hasLength(scrollId)) { @@ -115,6 +137,19 @@ public final void setScroll(String scrollId) { this.scrollId.set(scrollId); } + public interface AsyncResponse { + /** + * The response data made available. + */ + Response response(); + + /** + * Called when done processing response to signal more data is needed. + * @param extraKeepAlive extra time to keep underlying scroll open. + */ + void done(TimeValue extraKeepAlive); + } + /** * Response from each scroll batch. */