Skip to content

Add circuit breaker to storing async response #73638

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,15 @@ private void ensureCapacity(int size) {
assert overflow.size() >= overflowSize;
}

/**
* Returns the current size of the buffer.
*
* @return the number of bytes in this output stream.
*/
public int size() {
return position;
}

@Override
public void writeBytes(byte[] b, int offset, int length) {
if (position < buffer.length) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
import org.elasticsearch.common.util.concurrent.ThreadContext;
import org.elasticsearch.index.engine.DocumentMissingException;
import org.elasticsearch.index.engine.VersionConflictEngineException;
import org.elasticsearch.indices.breaker.CircuitBreakerService;
import org.elasticsearch.search.SearchService;
import org.elasticsearch.search.aggregations.InternalAggregation;
import org.elasticsearch.tasks.Task;
Expand All @@ -51,13 +52,15 @@ public class TransportSubmitAsyncSearchAction extends HandledTransportAction<Sub
private static final Logger logger = LogManager.getLogger(TransportSubmitAsyncSearchAction.class);

private final NodeClient nodeClient;
private final CircuitBreakerService circuitBreakerService;
private final Function<SearchRequest, InternalAggregation.ReduceContext> requestToAggReduceContextBuilder;
private final TransportSearchAction searchAction;
private final ThreadContext threadContext;
private final AsyncTaskIndexService<AsyncSearchResponse> store;

@Inject
public TransportSubmitAsyncSearchAction(ClusterService clusterService,
CircuitBreakerService circuitBreakerService,
TransportService transportService,
ActionFilters actionFilters,
NamedWriteableRegistry registry,
Expand All @@ -67,6 +70,7 @@ public TransportSubmitAsyncSearchAction(ClusterService clusterService,
TransportSearchAction searchAction) {
super(SubmitAsyncSearchAction.NAME, transportService, actionFilters, SubmitAsyncSearchRequest::new);
this.nodeClient = nodeClient;
this.circuitBreakerService = circuitBreakerService;
this.requestToAggReduceContextBuilder = request -> searchService.aggReduceContextBuilder(request).forFinalReduction();
this.searchAction = searchAction;
this.threadContext = transportService.getThreadPool().getThreadContext();
Expand All @@ -92,6 +96,7 @@ public void onResponse(AsyncSearchResponse searchResponse) {
// TODO: store intermediate results ?
AsyncSearchResponse initialResp = searchResponse.clone(searchResponse.getId());
store.createResponse(docId, searchTask.getOriginHeaders(), initialResp,
circuitBreakerService,
new ActionListener<>() {
@Override
public void onResponse(IndexResponse r) {
Expand Down Expand Up @@ -175,7 +180,7 @@ private void onFatalFailure(AsyncSearchTask task, Exception error, boolean shoul
private void onFinalResponse(AsyncSearchTask searchTask,
AsyncSearchResponse response,
Runnable nextAction) {
store.updateResponse(searchTask.getExecutionId().getDocId(), threadContext.getResponseHeaders(),response,
store.updateResponse(searchTask.getExecutionId().getDocId(), threadContext.getResponseHeaders(), response, circuitBreakerService,
ActionListener.wrap(resp -> unregisterTaskAndMoveOn(searchTask, nextAction),
exc -> {
Throwable cause = ExceptionsHelper.unwrapCause(exc);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,19 +21,23 @@
import org.elasticsearch.cluster.metadata.IndexMetadata;
import org.elasticsearch.cluster.service.ClusterService;
import org.elasticsearch.common.TriFunction;
import org.elasticsearch.common.bytes.BytesReference;
import org.elasticsearch.common.breaker.CircuitBreaker;
import org.elasticsearch.common.breaker.CircuitBreakingException;
import org.elasticsearch.common.bytes.RecyclingBytesStreamOutput;
import org.elasticsearch.common.collect.Tuple;
import org.elasticsearch.common.io.stream.ByteBufferStreamInput;
import org.elasticsearch.common.io.stream.BytesStreamOutput;
import org.elasticsearch.common.io.stream.NamedWriteableAwareStreamInput;
import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.common.lease.Releasable;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.common.util.BigArrays;
import org.elasticsearch.common.util.concurrent.ThreadContext;
import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.common.xcontent.XContentType;
import org.elasticsearch.indices.SystemIndexDescriptor;
import org.elasticsearch.indices.breaker.CircuitBreakerService;
import org.elasticsearch.search.fetch.subphase.FetchSourceContext;
import org.elasticsearch.tasks.Task;
import org.elasticsearch.tasks.TaskManager;
Expand Down Expand Up @@ -180,16 +184,23 @@ public Authentication getAuthentication() {
public void createResponse(String docId,
Map<String, String> headers,
R response,
ActionListener<IndexResponse> listener) throws IOException {
Map<String, Object> source = new HashMap<>();
source.put(HEADERS_FIELD, headers);
source.put(EXPIRATION_TIME_FIELD, response.getExpirationTime());
source.put(RESULT_FIELD, encodeResponse(response));
IndexRequest indexRequest = new IndexRequest(index)
.create(true)
.id(docId)
.source(source, XContentType.JSON);
clientWithOrigin.index(indexRequest, listener);
CircuitBreakerService circuitBreakerService,
ActionListener<IndexResponse> listener0) throws IOException {
AsyncResponseUpdateContext updateContext = new AsyncResponseUpdateContext(circuitBreakerService);
ActionListener<IndexResponse> listener = ActionListener.runAfter(listener0, () -> updateContext.close());
try {
Map<String, Object> source = new HashMap<>();
source.put(HEADERS_FIELD, headers);
source.put(EXPIRATION_TIME_FIELD, response.getExpirationTime());
source.put(RESULT_FIELD, encodeResponse(response, updateContext));
IndexRequest indexRequest = new IndexRequest(index)
.create(true)
.id(docId)
.source(source, XContentType.JSON);
clientWithOrigin.index(indexRequest, listener);
} catch(Exception e) {
listener.onFailure(e);
}
}

/**
Expand All @@ -198,11 +209,14 @@ public void createResponse(String docId,
public void updateResponse(String docId,
Map<String, List<String>> responseHeaders,
R response,
ActionListener<UpdateResponse> listener) {
CircuitBreakerService circuitBreakerService,
ActionListener<UpdateResponse> listener0) {
AsyncResponseUpdateContext updateContext = new AsyncResponseUpdateContext(circuitBreakerService);
ActionListener<UpdateResponse> listener = ActionListener.runAfter(listener0, () -> updateContext.close());
try {
Map<String, Object> source = new HashMap<>();
source.put(RESPONSE_HEADERS_FIELD, responseHeaders);
source.put(RESULT_FIELD, encodeResponse(response));
source.put(RESULT_FIELD, encodeResponse(response, updateContext));
UpdateRequest request = new UpdateRequest()
.index(index)
.id(docId)
Expand Down Expand Up @@ -453,13 +467,27 @@ boolean ensureAuthenticatedUserIsSame(Map<String, String> originHeaders, Authent
}

/**
* Encode the provided response in a binary form using base64 encoding.
* Encodes the provided response in a binary form using base64 encoding.
* Needs approximately up to 3.3X extra memory, where X is the original response size:
* - extra X bytes - for RecyclingBytesStreamOutput that encodes the response in an array of bytes,
* this memory allocation will be tracked automatically by BigArrays with circuitBreaker
* - up to X bytes – for converting bytes stream to bytes array
* - up to 1.3X bytes for encoded string, as Base64 adds around 33% overhead
* @throws CircuitBreakingException
*/
String encodeResponse(R response) throws IOException {
try (BytesStreamOutput out = new BytesStreamOutput()) {
String encodeResponse(R response, AsyncResponseUpdateContext updateContext) throws IOException {
BigArrays bigArrays = new BigArrays(
null, updateContext.circuitBreakerService(), CircuitBreaker.REQUEST).withCircuitBreaking();
// using RecyclingBytesStreamOutput allows to supply BigArrays with a circuit breaker
try (RecyclingBytesStreamOutput out = new RecyclingBytesStreamOutput(new byte[0], bigArrays)) {
Version.writeVersion(Version.CURRENT, out);
response.writeTo(out);
return Base64.getEncoder().encodeToString(BytesReference.toBytes(out.bytes()));

// need to check from circuitBreaker if additional 2.3X size is available
long estimatedSize = Math.round(out.size() * 2.3);
updateContext.addCircuitBreakerBytes(estimatedSize);

return Base64.getEncoder().encodeToString(out.toBytesRef().bytes);
}
}

Expand All @@ -485,4 +513,32 @@ public static void restoreResponseHeadersContext(ThreadContext threadContext, Ma
}
}
}

/**
* A helper class for updating async search responses to track the memory usage
*/
static class AsyncResponseUpdateContext implements Releasable {
private long circuitBreakerBytes = 0L;
private CircuitBreakerService circuitBreakerService;

AsyncResponseUpdateContext(CircuitBreakerService circuitBreakerService) {
assert circuitBreakerService != null : "Circuit breaker service must be provided when storing async search response!";
this.circuitBreakerService = circuitBreakerService;
}

public CircuitBreakerService circuitBreakerService() {
return circuitBreakerService;
}

public void addCircuitBreakerBytes(long estimatedSize) {
circuitBreakerService.getBreaker(CircuitBreaker.REQUEST)
.addEstimateBytesAndMaybeBreak(estimatedSize, "<storing_async_search_response>");
circuitBreakerBytes += estimatedSize;
}

@Override
public void close() {
circuitBreakerService.getBreaker(CircuitBreaker.REQUEST).addWithoutBreaking(-circuitBreakerBytes);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
import org.elasticsearch.action.update.UpdateResponse;
import org.elasticsearch.cluster.service.ClusterService;
import org.elasticsearch.common.unit.TimeValue;
import org.elasticsearch.indices.breaker.CircuitBreakerService;
import org.elasticsearch.tasks.CancellableTask;
import org.elasticsearch.tasks.Task;
import org.elasticsearch.tasks.TaskId;
Expand All @@ -39,6 +40,7 @@

public class AsyncResultsServiceTests extends ESSingleNodeTestCase {
private ClusterService clusterService;
private CircuitBreakerService circuitBreakerService;
private TaskManager taskManager;
private AsyncTaskIndexService<TestAsyncResponse> indexService;

Expand Down Expand Up @@ -122,6 +124,7 @@ public Task createTask(long id, String type, String action, TaskId parentTaskId,
@Before
public void setup() {
clusterService = getInstanceFromNode(ClusterService.class);
circuitBreakerService = getInstanceFromNode(CircuitBreakerService.class);
TransportService transportService = getInstanceFromNode(TransportService.class);
taskManager = transportService.getTaskManager();
indexService = new AsyncTaskIndexService<>("test", clusterService, transportService.getThreadPool().getThreadContext(),
Expand Down Expand Up @@ -162,7 +165,7 @@ public void testRetrieveFromMemoryWithExpiration() throws Exception {
// we need to store initial result
PlainActionFuture<IndexResponse> future = new PlainActionFuture<>();
indexService.createResponse(task.getExecutionId().getDocId(), task.getOriginHeaders(),
new TestAsyncResponse(null, task.getExpirationTime()), future);
new TestAsyncResponse(null, task.getExpirationTime()), circuitBreakerService, future);
future.actionGet(TimeValue.timeValueSeconds(10));
}

Expand Down Expand Up @@ -204,7 +207,7 @@ public void testAssertExpirationPropagation() throws Exception {
// we need to store initial result
PlainActionFuture<IndexResponse> future = new PlainActionFuture<>();
indexService.createResponse(task.getExecutionId().getDocId(), task.getOriginHeaders(),
new TestAsyncResponse(null, task.getExpirationTime()), future);
new TestAsyncResponse(null, task.getExpirationTime()), circuitBreakerService, future);
future.actionGet(TimeValue.timeValueSeconds(10));
}

Expand Down Expand Up @@ -242,17 +245,17 @@ public void testRetrieveFromDisk() throws Exception {
// we need to store initial result
PlainActionFuture<IndexResponse> futureCreate = new PlainActionFuture<>();
indexService.createResponse(task.getExecutionId().getDocId(), task.getOriginHeaders(),
new TestAsyncResponse(null, task.getExpirationTime()), futureCreate);
new TestAsyncResponse(null, task.getExpirationTime()), circuitBreakerService, futureCreate);
futureCreate.actionGet(TimeValue.timeValueSeconds(10));

PlainActionFuture<UpdateResponse> futureUpdate = new PlainActionFuture<>();
indexService.updateResponse(task.getExecutionId().getDocId(), emptyMap(),
new TestAsyncResponse("final_response", task.getExpirationTime()), futureUpdate);
new TestAsyncResponse("final_response", task.getExpirationTime()), circuitBreakerService, futureUpdate);
futureUpdate.actionGet(TimeValue.timeValueSeconds(10));
} else {
PlainActionFuture<IndexResponse> futureCreate = new PlainActionFuture<>();
indexService.createResponse(task.getExecutionId().getDocId(), task.getOriginHeaders(),
new TestAsyncResponse("final_response", task.getExpirationTime()), futureCreate);
new TestAsyncResponse("final_response", task.getExpirationTime()), circuitBreakerService, futureCreate);
futureCreate.actionGet(TimeValue.timeValueSeconds(10));
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import org.elasticsearch.cluster.service.ClusterService;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.indices.breaker.CircuitBreakerService;
import org.elasticsearch.test.ESSingleNodeTestCase;
import org.elasticsearch.transport.TransportService;
import org.junit.Before;
Expand All @@ -22,6 +23,7 @@
// TODO: test CRUD operations
public class AsyncSearchIndexServiceTests extends ESSingleNodeTestCase {
private AsyncTaskIndexService<TestAsyncResponse> indexService;
private CircuitBreakerService circuitBreakerService;

public static class TestAsyncResponse implements AsyncResponse<TestAsyncResponse> {
public final String test;
Expand Down Expand Up @@ -72,16 +74,20 @@ public int hashCode() {
public void setup() {
ClusterService clusterService = getInstanceFromNode(ClusterService.class);
TransportService transportService = getInstanceFromNode(TransportService.class);
circuitBreakerService = getInstanceFromNode(CircuitBreakerService.class);
indexService = new AsyncTaskIndexService<>("test", clusterService, transportService.getThreadPool().getThreadContext(),
client(), ASYNC_SEARCH_ORIGIN, TestAsyncResponse::new, writableRegistry());
}

public void testEncodeSearchResponse() throws IOException {
for (int i = 0; i < 10; i++) {
TestAsyncResponse response = new TestAsyncResponse(randomAlphaOfLength(10), randomLong());
String encoded = indexService.encodeResponse(response);
TestAsyncResponse same = indexService.decodeResponse(encoded);
assertThat(same, equalTo(response));
try (AsyncTaskIndexService.AsyncResponseUpdateContext updateContext =
new AsyncTaskIndexService.AsyncResponseUpdateContext(circuitBreakerService)) {
String encoded = indexService.encodeResponse(response, updateContext);
TestAsyncResponse same = indexService.decodeResponse(encoded);
assertThat(same, equalTo(response));
}
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.common.util.concurrent.ThreadContext;
import org.elasticsearch.indices.SystemIndexDescriptor;
import org.elasticsearch.indices.breaker.CircuitBreakerService;
import org.elasticsearch.plugins.Plugin;
import org.elasticsearch.plugins.SystemIndexPlugin;
import org.elasticsearch.tasks.TaskId;
Expand All @@ -36,13 +37,15 @@
// TODO: test CRUD operations
public class AsyncTaskServiceTests extends ESSingleNodeTestCase {
private AsyncTaskIndexService<AsyncSearchResponse> indexService;
private CircuitBreakerService circuitBreakerService;

public String index = ".async-search";

@Before
public void setup() {
ClusterService clusterService = getInstanceFromNode(ClusterService.class);
TransportService transportService = getInstanceFromNode(TransportService.class);
circuitBreakerService = getInstanceFromNode(CircuitBreakerService.class);
indexService = new AsyncTaskIndexService<>(index, clusterService,
transportService.getThreadPool().getThreadContext(),
client(), "test_origin", AsyncSearchResponse::new, writableRegistry());
Expand Down Expand Up @@ -138,7 +141,7 @@ public void testAutoCreateIndex() throws Exception {
AsyncSearchResponse resp = new AsyncSearchResponse(id.getEncoded(), true, true, 0L, 0L);
{
PlainActionFuture<IndexResponse> future = PlainActionFuture.newFuture();
indexService.createResponse(id.getDocId(), Collections.emptyMap(), resp, future);
indexService.createResponse(id.getDocId(), Collections.emptyMap(), resp, circuitBreakerService, future);
future.get();
assertSettings();
}
Expand All @@ -157,7 +160,7 @@ public void testAutoCreateIndex() throws Exception {
// So do updates
{
PlainActionFuture<UpdateResponse> future = PlainActionFuture.newFuture();
indexService.updateResponse(id.getDocId(), Collections.emptyMap(), resp, future);
indexService.updateResponse(id.getDocId(), Collections.emptyMap(), resp, circuitBreakerService, future);
expectThrows(Exception.class, future::get);
assertSettings();
}
Expand All @@ -173,7 +176,7 @@ public void testAutoCreateIndex() throws Exception {
// But the index is still auto-created
{
PlainActionFuture<IndexResponse> future = PlainActionFuture.newFuture();
indexService.createResponse(id.getDocId(), Collections.emptyMap(), resp, future);
indexService.createResponse(id.getDocId(), Collections.emptyMap(), resp, circuitBreakerService, future);
future.get();
assertSettings();
}
Expand Down
Loading