Skip to content

Extend async search keep alive #67877

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Jan 25, 2021
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

import org.apache.lucene.store.AlreadyClosedException;
import org.elasticsearch.ExceptionsHelper;
import org.elasticsearch.action.ActionRequestValidationException;
import org.elasticsearch.action.index.IndexRequestBuilder;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.common.unit.TimeValue;
Expand Down Expand Up @@ -345,7 +346,7 @@ public void testUpdateRunningKeepAlive() throws Exception {
assertThat(response.getSearchResponse().getSuccessfulShards(), equalTo(0));
assertThat(response.getSearchResponse().getFailedShards(), equalTo(0));

response = getAsyncSearch(response.getId(), TimeValue.timeValueDays(10));
response = getAsyncSearch(response.getId(), TimeValue.timeValueDays(6));
assertThat(response.getExpirationTime(), greaterThan(expirationTime));

assertTrue(response.isRunning());
Expand All @@ -364,8 +365,13 @@ public void testUpdateRunningKeepAlive() throws Exception {
assertEquals(0, statusResponse.getSkippedShards());
assertEquals(null, statusResponse.getCompletionStatus());

response = getAsyncSearch(response.getId(), TimeValue.timeValueMillis(1));
assertThat(response.getExpirationTime(), lessThan(expirationTime));
expirationTime = response.getExpirationTime();
response = getAsyncSearch(response.getId(), TimeValue.timeValueMinutes(between(1, 24 * 60)));
assertThat(response.getExpirationTime(), equalTo(expirationTime));
response = getAsyncSearch(response.getId(), TimeValue.timeValueDays(10));
assertThat(response.getExpirationTime(), greaterThan(expirationTime));

deleteAsyncSearch(response.getId());
ensureTaskNotRunning(response.getId());
ensureTaskRemoval(response.getId());
}
Expand All @@ -391,16 +397,21 @@ public void testUpdateStoreKeepAlive() throws Exception {
assertThat(response.getSearchResponse().getSuccessfulShards(), equalTo(numShards));
assertThat(response.getSearchResponse().getFailedShards(), equalTo(0));

response = getAsyncSearch(response.getId(), TimeValue.timeValueDays(10));
response = getAsyncSearch(response.getId(), TimeValue.timeValueDays(8));
assertThat(response.getExpirationTime(), greaterThan(expirationTime));
expirationTime = response.getExpirationTime();

assertFalse(response.isRunning());
assertThat(response.getSearchResponse().getTotalShards(), equalTo(numShards));
assertThat(response.getSearchResponse().getSuccessfulShards(), equalTo(numShards));
assertThat(response.getSearchResponse().getFailedShards(), equalTo(0));

response = getAsyncSearch(response.getId(), TimeValue.timeValueMillis(1));
assertThat(response.getExpirationTime(), lessThan(expirationTime));
assertThat(response.getExpirationTime(), equalTo(expirationTime));
response = getAsyncSearch(response.getId(), TimeValue.timeValueDays(10));
assertThat(response.getExpirationTime(), greaterThan(expirationTime));

deleteAsyncSearch(response.getId());
ensureTaskNotRunning(response.getId());
ensureTaskRemoval(response.getId());
}
Expand All @@ -427,22 +438,24 @@ public void testRemoveAsyncIndex() throws Exception {
ExceptionsHelper.unwrapCause(exc.getCause()) : ExceptionsHelper.unwrapCause(exc);
assertThat(ExceptionsHelper.status(cause).getStatus(), equalTo(404));

SubmitAsyncSearchRequest newReq = new SubmitAsyncSearchRequest(indexName);
SubmitAsyncSearchRequest newReq = new SubmitAsyncSearchRequest(indexName) {
@Override
public ActionRequestValidationException validate() {
return null; // to use a small keep_alive
}
};
newReq.getSearchRequest().source(
new SearchSourceBuilder().aggregation(new CancellingAggregationBuilder("test", randomLong()))
);
newReq.setWaitForCompletionTimeout(TimeValue.timeValueMillis(1));
newReq.setWaitForCompletionTimeout(TimeValue.timeValueMillis(1)).setKeepAlive(TimeValue.timeValueSeconds(5));
AsyncSearchResponse newResp = submitAsyncSearch(newReq);
assertNotNull(newResp.getSearchResponse());
assertTrue(newResp.isRunning());
assertThat(newResp.getSearchResponse().getTotalShards(), equalTo(numShards));
assertThat(newResp.getSearchResponse().getSuccessfulShards(), equalTo(0));
assertThat(newResp.getSearchResponse().getFailedShards(), equalTo(0));
long expirationTime = newResp.getExpirationTime();

// check garbage collection
newResp = getAsyncSearch(newResp.getId(), TimeValue.timeValueMillis(1));
assertThat(newResp.getExpirationTime(), lessThan(expirationTime));
ensureTaskNotRunning(newResp.getId());
ensureTaskRemoval(newResp.getId());
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import org.elasticsearch.plugins.Plugin;
import org.elasticsearch.plugins.SearchPlugin;
import org.elasticsearch.rest.RestStatus;
import org.elasticsearch.script.MockScriptPlugin;
import org.elasticsearch.search.aggregations.bucket.filter.InternalFilter;
import org.elasticsearch.search.builder.PointInTimeBuilder;
import org.elasticsearch.search.builder.SearchSourceBuilder;
Expand Down Expand Up @@ -57,11 +58,15 @@
import java.util.Collections;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.function.Function;

import static org.elasticsearch.xpack.core.XPackPlugin.ASYNC_RESULTS_INDEX;
import static org.elasticsearch.xpack.core.async.AsyncTaskMaintenanceService.ASYNC_SEARCH_CLEANUP_INTERVAL_SETTING;
import static org.hamcrest.Matchers.contains;
import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.lessThanOrEqualTo;

Expand Down Expand Up @@ -93,6 +98,37 @@ public List<AggregationSpec> getAggregations() {
}
}

public static class ExpirationTimeScriptPlugin extends MockScriptPlugin {
@Override
public String pluginScriptLang() {
return "painless";
}

@Override
@SuppressWarnings("unchecked")
protected Map<String, Function<Map<String, Object>, Object>> pluginScripts() {
final String fieldName = "expiration_time";
final String script =
"if (ctx._source.expiration_time < params.expiration_time) ctx._source.expiration_time = params.expiration_time";
return Map.of(
script, vars -> {
Map<String, Object> params = (Map<String, Object>) vars.get("params");
assertNotNull(params);
assertThat(params.keySet(), contains(fieldName));
long updatingValue = (long) params.get(fieldName);

Map<String, Object> ctx = (Map<String, Object>) vars.get("ctx");
assertNotNull(ctx);
Map<String, Object> source = (Map<String, Object>) ctx.get("_source");
long currentValue = (long) source.get(fieldName);

source.put(fieldName, Math.max(currentValue, updatingValue));
return ctx;
}
);
}
}

@Before
public void startMaintenanceService() {
for (AsyncTaskMaintenanceService service : internalCluster().getDataNodeInstances(AsyncTaskMaintenanceService.class)) {
Expand Down Expand Up @@ -120,7 +156,7 @@ public void releaseQueryLatch() {
@Override
protected Collection<Class<? extends Plugin>> nodePlugins() {
return Arrays.asList(LocalStateCompositeXPackPlugin.class, AsyncSearch.class, AsyncResultsIndexPlugin.class, IndexLifecycle.class,
SearchTestPlugin.class, ReindexPlugin.class);
SearchTestPlugin.class, ReindexPlugin.class, ExpirationTimeScriptPlugin.class);
}

@Override
Expand Down Expand Up @@ -189,7 +225,7 @@ protected void ensureTaskNotRunning(String id) throws Exception {
throw exc;
}
}
});
}, 30, TimeUnit.SECONDS);
}

/**
Expand All @@ -207,7 +243,7 @@ protected void ensureTaskCompletion(String id) throws Exception {
throw exc;
}
}
});
}, 30, TimeUnit.SECONDS);
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
import java.util.List;
import java.util.Map;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicLong;
import java.util.concurrent.atomic.AtomicReference;
import java.util.function.Consumer;
import java.util.function.Supplier;
Expand All @@ -62,7 +63,7 @@ final class AsyncSearchTask extends SearchTask implements AsyncTask {
private final List<Runnable> initListeners = new ArrayList<>();
private final Map<Long, Consumer<AsyncSearchResponse>> completionListeners = new HashMap<>();

private volatile long expirationTimeMillis;
private final AtomicLong expirationTimeMillis;
private final AtomicBoolean isCancelling = new AtomicBoolean(false);

private final AtomicReference<MutableSearchResponse> searchResponse = new AtomicReference<>();
Expand Down Expand Up @@ -93,7 +94,7 @@ final class AsyncSearchTask extends SearchTask implements AsyncTask {
ThreadPool threadPool,
Supplier<InternalAggregation.ReduceContext> aggReduceContextSupplier) {
super(id, type, action, () -> "async_search{" + descriptionSupplier.get() + "}", parentTaskId, taskHeaders);
this.expirationTimeMillis = getStartTime() + keepAlive.getMillis();
this.expirationTimeMillis = new AtomicLong(getStartTime() + keepAlive.getMillis());
this.originHeaders = originHeaders;
this.searchId = searchId;
this.client = client;
Expand Down Expand Up @@ -128,7 +129,7 @@ Listener getSearchProgressActionListener() {
*/
@Override
public void setExpirationTime(long expirationTimeMillis) {
this.expirationTimeMillis = expirationTimeMillis;
this.expirationTimeMillis.updateAndGet(curr -> Math.max(curr, expirationTimeMillis));
}

@Override
Expand Down Expand Up @@ -330,19 +331,19 @@ private AsyncSearchResponse getResponse(boolean restoreResponseHeaders) {
checkCancellation();
AsyncSearchResponse asyncSearchResponse;
try {
asyncSearchResponse = mutableSearchResponse.toAsyncSearchResponse(this, expirationTimeMillis, restoreResponseHeaders);
asyncSearchResponse = mutableSearchResponse.toAsyncSearchResponse(this, expirationTimeMillis.get(), restoreResponseHeaders);
} catch(Exception e) {
ElasticsearchException exception = new ElasticsearchStatusException("Async search: error while reducing partial results",
ExceptionsHelper.status(e), e);
asyncSearchResponse = mutableSearchResponse.toAsyncSearchResponse(this, expirationTimeMillis, exception);
asyncSearchResponse = mutableSearchResponse.toAsyncSearchResponse(this, expirationTimeMillis.get(), exception);
}
return asyncSearchResponse;
}

// checks if the search task should be cancelled
private synchronized void checkCancellation() {
long now = System.currentTimeMillis();
if (hasCompleted == false && expirationTimeMillis < now) {
if (hasCompleted == false && expirationTimeMillis.get() < now) {
// we cancel expired search task even if they are still running
cancelTask(() -> {}, "async search has expired");
}
Expand All @@ -354,7 +355,7 @@ private synchronized void checkCancellation() {
public AsyncStatusResponse getStatusResponse() {
MutableSearchResponse mutableSearchResponse = searchResponse.get();
assert mutableSearchResponse != null;
return mutableSearchResponse.toStatusResponse(searchId.getEncoded(), getStartTime(), expirationTimeMillis);
return mutableSearchResponse.toStatusResponse(searchId.getEncoded(), getStartTime(), expirationTimeMillis.get());
}

class Listener extends SearchProgressActionListener {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,8 @@
import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.common.xcontent.XContentType;
import org.elasticsearch.indices.SystemIndexDescriptor;
import org.elasticsearch.script.Script;
import org.elasticsearch.script.ScriptType;
import org.elasticsearch.tasks.Task;
import org.elasticsearch.tasks.TaskManager;
import org.elasticsearch.xpack.core.XPackPlugin;
Expand All @@ -45,9 +47,9 @@
import java.io.UncheckedIOException;
import java.nio.ByteBuffer;
import java.util.Base64;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.function.BiFunction;

Expand All @@ -64,6 +66,10 @@ public final class AsyncTaskIndexService<R extends AsyncResponse<R>> {
public static final String HEADERS_FIELD = "headers";
public static final String RESPONSE_HEADERS_FIELD = "response_headers";
public static final String EXPIRATION_TIME_FIELD = "expiration_time";
public static final String EXPIRATION_TIME_SCRIPT = String.format(Locale.ROOT,
"if (ctx._source.%s < params.%s) ctx._source.%s = params.%s",
EXPIRATION_TIME_FIELD, EXPIRATION_TIME_FIELD, EXPIRATION_TIME_FIELD, EXPIRATION_TIME_FIELD);

public static final String RESULT_FIELD = "result";

// Usually the settings, mappings and system index descriptor below
Expand Down Expand Up @@ -202,10 +208,12 @@ public void updateResponse(String docId,
public void updateExpirationTime(String docId,
long expirationTimeMillis,
ActionListener<UpdateResponse> listener) {
Map<String, Object> source = Collections.singletonMap(EXPIRATION_TIME_FIELD, expirationTimeMillis);
UpdateRequest request = new UpdateRequest().index(index)
Script script = new Script(ScriptType.INLINE, "painless", EXPIRATION_TIME_SCRIPT,
Map.of(EXPIRATION_TIME_FIELD, expirationTimeMillis));
UpdateRequest request = new UpdateRequest()
.index(index)
.id(docId)
.doc(source, XContentType.JSON)
.script(script)
.retryOnConflict(5);
client.update(request, listener);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
import org.elasticsearch.action.update.UpdateResponse;
import org.elasticsearch.cluster.service.ClusterService;
import org.elasticsearch.common.unit.TimeValue;
import org.elasticsearch.plugins.Plugin;
import org.elasticsearch.tasks.CancellableTask;
import org.elasticsearch.tasks.Task;
import org.elasticsearch.tasks.TaskId;
Expand All @@ -23,6 +24,7 @@
import org.elasticsearch.xpack.core.async.AsyncSearchIndexServiceTests.TestAsyncResponse;
import org.junit.Before;

import java.util.Collection;
import java.util.HashMap;
import java.util.Map;

Expand Down Expand Up @@ -273,4 +275,9 @@ public void testRetrieveFromDisk() throws Exception {
deleteService.deleteResult(new DeleteAsyncResultRequest(task.getExecutionId().getEncoded()), deleteListener);
assertFutureThrows(deleteListener, ResourceNotFoundException.class);
}

@Override
protected Collection<Class<? extends Plugin>> getPlugins() {
return pluginList(ExpirationTimeScriptPlugin.class);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ public void setup() {
protected Collection<Class<? extends Plugin>> getPlugins() {
List<Class<? extends Plugin>> plugins = new ArrayList<>(super.getPlugins());
plugins.add(TestPlugin.class);
plugins.add(ExpirationTimeScriptPlugin.class);
return plugins;
}

Expand Down Expand Up @@ -175,4 +176,6 @@ private void assertSettings() {
Settings expected = AsyncTaskIndexService.settings();
assertEquals(expected, settings.filter(expected::hasValue));
}


}
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License;
* you may not use this file except in compliance with the Elastic License.
*/

/*
* Licensed to Elasticsearch under one or more contributor
* license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright
* ownership. Elasticsearch licenses this file to you under
* the Apache License, Version 2.0 (the "License"); you may
* not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

package org.elasticsearch.xpack.core.async;

import org.elasticsearch.script.MockScriptPlugin;
import org.junit.Assert;

import java.util.Map;
import java.util.function.Function;

import static org.hamcrest.Matchers.contains;

public class ExpirationTimeScriptPlugin extends MockScriptPlugin {
@Override
public String pluginScriptLang() {
return "painless";
}

@Override
@SuppressWarnings("unchecked")
protected Map<String, Function<Map<String, Object>, Object>> pluginScripts() {
final String fieldName = "expiration_time";
final String script =
"if (ctx._source.expiration_time < params.expiration_time) ctx._source.expiration_time = params.expiration_time";
return Map.of(
script, vars -> {
Map<String, Object> params = (Map<String, Object>) vars.get("params");
Assert.assertNotNull(params);
Assert.assertThat(params.keySet(), contains(fieldName));
long updatingValue = (long) params.get(fieldName);

Map<String, Object> ctx = (Map<String, Object>) vars.get("ctx");
Assert.assertNotNull(ctx);
Map<String, Object> source = (Map<String, Object>) ctx.get("_source");
long currentValue = (long) source.get(fieldName);

source.put(fieldName, Math.max(currentValue, updatingValue));
return ctx;
}
);
}
}
Loading