Skip to content

HLRest: Allow caller to set per request options #30490

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 20 commits into from
May 31, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -279,6 +279,17 @@ public final TasksClient tasks() {
*
* See <a href="https://www.elastic.co/guide/en/elasticsearch/reference/current/docs-bulk.html">Bulk API on elastic.co</a>
*/
public final BulkResponse bulk(BulkRequest bulkRequest, RequestOptions options) throws IOException {
return performRequestAndParseEntity(bulkRequest, RequestConverters::bulk, options, BulkResponse::fromXContent, emptySet());
}

/**
* Executes a bulk request using the Bulk API
*
* See <a href="https://www.elastic.co/guide/en/elasticsearch/reference/current/docs-bulk.html">Bulk API on elastic.co</a>
* @deprecated Prefer {@link #bulk(BulkRequest, RequestOptions)}
*/
@Deprecated
public final BulkResponse bulk(BulkRequest bulkRequest, Header... headers) throws IOException {
return performRequestAndParseEntity(bulkRequest, RequestConverters::bulk, BulkResponse::fromXContent, emptySet(), headers);
}
Expand All @@ -288,6 +299,17 @@ public final BulkResponse bulk(BulkRequest bulkRequest, Header... headers) throw
*
* See <a href="https://www.elastic.co/guide/en/elasticsearch/reference/current/docs-bulk.html">Bulk API on elastic.co</a>
*/
public final void bulkAsync(BulkRequest bulkRequest, RequestOptions options, ActionListener<BulkResponse> listener) {
performRequestAsyncAndParseEntity(bulkRequest, RequestConverters::bulk, options, BulkResponse::fromXContent, listener, emptySet());
}

/**
* Asynchronously executes a bulk request using the Bulk API
*
* See <a href="https://www.elastic.co/guide/en/elasticsearch/reference/current/docs-bulk.html">Bulk API on elastic.co</a>
* @deprecated Prefer {@link #bulkAsync(BulkRequest, RequestOptions, ActionListener)}
*/
@Deprecated
public final void bulkAsync(BulkRequest bulkRequest, ActionListener<BulkResponse> listener, Header... headers) {
performRequestAsyncAndParseEntity(bulkRequest, RequestConverters::bulk, BulkResponse::fromXContent, listener, emptySet(), headers);
}
Expand Down Expand Up @@ -584,23 +606,42 @@ public final void fieldCapsAsync(FieldCapabilitiesRequest fieldCapabilitiesReque
FieldCapabilitiesResponse::fromXContent, listener, emptySet(), headers);
}

@Deprecated
protected final <Req extends ActionRequest, Resp> Resp performRequestAndParseEntity(Req request,
CheckedFunction<Req, Request, IOException> requestConverter,
CheckedFunction<XContentParser, Resp, IOException> entityParser,
Set<Integer> ignores, Header... headers) throws IOException {
return performRequest(request, requestConverter, (response) -> parseEntity(response.getEntity(), entityParser), ignores, headers);
}

protected final <Req extends ActionRequest, Resp> Resp performRequestAndParseEntity(Req request,
CheckedFunction<Req, Request, IOException> requestConverter,
RequestOptions options,
CheckedFunction<XContentParser, Resp, IOException> entityParser,
Set<Integer> ignores) throws IOException {
return performRequest(request, requestConverter, options,
response -> parseEntity(response.getEntity(), entityParser), ignores);
}

@Deprecated
protected final <Req extends ActionRequest, Resp> Resp performRequest(Req request,
CheckedFunction<Req, Request, IOException> requestConverter,
CheckedFunction<Response, Resp, IOException> responseConverter,
Set<Integer> ignores, Header... headers) throws IOException {
return performRequest(request, requestConverter, optionsForHeaders(headers), responseConverter, ignores);
}

protected final <Req extends ActionRequest, Resp> Resp performRequest(Req request,
CheckedFunction<Req, Request, IOException> requestConverter,
RequestOptions options,
CheckedFunction<Response, Resp, IOException> responseConverter,
Set<Integer> ignores) throws IOException {
ActionRequestValidationException validationException = request.validate();
if (validationException != null) {
throw validationException;
}
Request req = requestConverter.apply(request);
addHeaders(req, headers);
req.setOptions(options);
Response response;
try {
response = client.performRequest(req);
Expand All @@ -626,6 +667,7 @@ protected final <Req extends ActionRequest, Resp> Resp performRequest(Req reques
}
}

@Deprecated
protected final <Req extends ActionRequest, Resp> void performRequestAsyncAndParseEntity(Req request,
CheckedFunction<Req, Request, IOException> requestConverter,
CheckedFunction<XContentParser, Resp, IOException> entityParser,
Expand All @@ -634,10 +676,28 @@ protected final <Req extends ActionRequest, Resp> void performRequestAsyncAndPar
listener, ignores, headers);
}

protected final <Req extends ActionRequest, Resp> void performRequestAsyncAndParseEntity(Req request,
CheckedFunction<Req, Request, IOException> requestConverter,
RequestOptions options,
CheckedFunction<XContentParser, Resp, IOException> entityParser,
ActionListener<Resp> listener, Set<Integer> ignores) {
performRequestAsync(request, requestConverter, options,
response -> parseEntity(response.getEntity(), entityParser), listener, ignores);
}

@Deprecated
protected final <Req extends ActionRequest, Resp> void performRequestAsync(Req request,
CheckedFunction<Req, Request, IOException> requestConverter,
CheckedFunction<Response, Resp, IOException> responseConverter,
ActionListener<Resp> listener, Set<Integer> ignores, Header... headers) {
performRequestAsync(request, requestConverter, optionsForHeaders(headers), responseConverter, listener, ignores);
}

protected final <Req extends ActionRequest, Resp> void performRequestAsync(Req request,
CheckedFunction<Req, Request, IOException> requestConverter,
RequestOptions options,
CheckedFunction<Response, Resp, IOException> responseConverter,
ActionListener<Resp> listener, Set<Integer> ignores) {
ActionRequestValidationException validationException = request.validate();
if (validationException != null) {
listener.onFailure(validationException);
Expand All @@ -650,19 +710,12 @@ protected final <Req extends ActionRequest, Resp> void performRequestAsync(Req r
listener.onFailure(e);
return;
}
addHeaders(req, headers);
req.setOptions(options);

ResponseListener responseListener = wrapResponseListener(responseConverter, listener, ignores);
client.performRequestAsync(req, responseListener);
}

private static void addHeaders(Request request, Header... headers) {
Objects.requireNonNull(headers, "headers cannot be null");
for (Header header : headers) {
request.addHeader(header.getName(), header.getValue());
}
}

final <Resp> ResponseListener wrapResponseListener(CheckedFunction<Response, Resp, IOException> responseConverter,
ActionListener<Resp> actionListener, Set<Integer> ignores) {
return new ResponseListener() {
Expand Down Expand Up @@ -746,6 +799,15 @@ protected final <Resp> Resp parseEntity(final HttpEntity entity,
}
}

private static RequestOptions optionsForHeaders(Header[] headers) {
RequestOptions.Builder options = RequestOptions.DEFAULT.toBuilder();
for (Header header : headers) {
Objects.requireNonNull(header, "header cannot be null");
options.addHeader(header.getName(), header.getValue());
}
return options.build();
}

static boolean convertExistsResponse(Response response) {
return response.getStatusLine().getStatusCode() == 200;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@
import org.apache.http.client.methods.HttpGet;
import org.apache.http.entity.ByteArrayEntity;
import org.apache.http.entity.ContentType;
import org.apache.http.message.BasicHeader;
import org.apache.http.message.BasicRequestLine;
import org.apache.http.message.BasicStatusLine;
import org.apache.lucene.util.BytesRef;
Expand All @@ -48,11 +47,13 @@
import java.lang.reflect.Modifier;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Set;
import java.util.TreeSet;
import java.util.stream.Collectors;

import static java.util.Collections.emptySet;
import static org.hamcrest.Matchers.containsInAnyOrder;
import static org.hamcrest.Matchers.contains;
import static org.hamcrest.Matchers.hasSize;
import static org.mockito.Matchers.any;
import static org.mockito.Mockito.doAnswer;
import static org.mockito.Mockito.mock;
Expand All @@ -73,12 +74,12 @@ public void initClients() throws IOException {
final RestClient restClient = mock(RestClient.class);
restHighLevelClient = new CustomRestClient(restClient);

doAnswer(inv -> mockPerformRequest(((Request) inv.getArguments()[0]).getHeaders().iterator().next()))
doAnswer(inv -> mockPerformRequest((Request) inv.getArguments()[0]))
.when(restClient)
.performRequest(any(Request.class));

doAnswer(inv -> mockPerformRequestAsync(
((Request) inv.getArguments()[0]).getHeaders().iterator().next(),
((Request) inv.getArguments()[0]),
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can you fix the javadocs link for mockPerformRequestAsync ?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

👍

(ResponseListener) inv.getArguments()[1]))
.when(restClient)
.performRequestAsync(any(Request.class), any(ResponseListener.class));
Expand All @@ -87,26 +88,32 @@ public void initClients() throws IOException {

public void testCustomEndpoint() throws IOException {
final MainRequest request = new MainRequest();
final Header header = new BasicHeader("node_name", randomAlphaOfLengthBetween(1, 10));
String nodeName = randomAlphaOfLengthBetween(1, 10);

MainResponse response = restHighLevelClient.custom(request, header);
assertEquals(header.getValue(), response.getNodeName());
MainResponse response = restHighLevelClient.custom(request, optionsForNodeName(nodeName));
assertEquals(nodeName, response.getNodeName());

response = restHighLevelClient.customAndParse(request, header);
assertEquals(header.getValue(), response.getNodeName());
response = restHighLevelClient.customAndParse(request, optionsForNodeName(nodeName));
assertEquals(nodeName, response.getNodeName());
}

public void testCustomEndpointAsync() throws Exception {
final MainRequest request = new MainRequest();
final Header header = new BasicHeader("node_name", randomAlphaOfLengthBetween(1, 10));
String nodeName = randomAlphaOfLengthBetween(1, 10);

PlainActionFuture<MainResponse> future = PlainActionFuture.newFuture();
restHighLevelClient.customAsync(request, future, header);
assertEquals(header.getValue(), future.get().getNodeName());
restHighLevelClient.customAsync(request, optionsForNodeName(nodeName), future);
assertEquals(nodeName, future.get().getNodeName());

future = PlainActionFuture.newFuture();
restHighLevelClient.customAndParseAsync(request, future, header);
assertEquals(header.getValue(), future.get().getNodeName());
restHighLevelClient.customAndParseAsync(request, optionsForNodeName(nodeName), future);
assertEquals(nodeName, future.get().getNodeName());
}

private static RequestOptions optionsForNodeName(String nodeName) {
RequestOptions.Builder options = RequestOptions.DEFAULT.toBuilder();
options.addHeader("node_name", nodeName);
return options.build();
}

/**
Expand All @@ -115,27 +122,27 @@ public void testCustomEndpointAsync() throws Exception {
*/
@SuppressForbidden(reason = "We're forced to uses Class#getDeclaredMethods() here because this test checks protected methods")
public void testMethodsVisibility() throws ClassNotFoundException {
final String[] methodNames = new String[]{"performRequest",
"performRequestAsync",
final String[] methodNames = new String[]{"parseEntity",
"parseResponseException",
"performRequest",
"performRequestAndParseEntity",
"performRequestAsyncAndParseEntity",
"parseEntity",
"parseResponseException"};
"performRequestAsync",
"performRequestAsyncAndParseEntity"};

final List<String> protectedMethods = Arrays.stream(RestHighLevelClient.class.getDeclaredMethods())
final Set<String> protectedMethods = Arrays.stream(RestHighLevelClient.class.getDeclaredMethods())
.filter(method -> Modifier.isProtected(method.getModifiers()))
.map(Method::getName)
.collect(Collectors.toList());
.collect(Collectors.toCollection(TreeSet::new));

assertThat(protectedMethods, containsInAnyOrder(methodNames));
assertThat(protectedMethods, contains(methodNames));
}

/**
* Mocks the asynchronous request execution by calling the {@link #mockPerformRequest(Header)} method.
* Mocks the asynchronous request execution by calling the {@link #mockPerformRequest(Request)} method.
*/
private Void mockPerformRequestAsync(Header httpHeader, ResponseListener responseListener) {
private Void mockPerformRequestAsync(Request request, ResponseListener responseListener) {
try {
responseListener.onSuccess(mockPerformRequest(httpHeader));
responseListener.onSuccess(mockPerformRequest(request));
} catch (IOException e) {
responseListener.onFailure(e);
}
Expand All @@ -145,7 +152,9 @@ private Void mockPerformRequestAsync(Header httpHeader, ResponseListener respons
/**
* Mocks the synchronous request execution like if it was executed by Elasticsearch.
*/
private Response mockPerformRequest(Header httpHeader) throws IOException {
private Response mockPerformRequest(Request request) throws IOException {
assertThat(request.getOptions().getHeaders(), hasSize(1));
Header httpHeader = request.getOptions().getHeaders().get(0);
final Response mockResponse = mock(Response.class);
when(mockResponse.getHost()).thenReturn(new HttpHost("localhost", 9200));

Expand All @@ -171,20 +180,20 @@ private CustomRestClient(RestClient restClient) {
super(restClient, RestClient::close, Collections.emptyList());
}

MainResponse custom(MainRequest mainRequest, Header... headers) throws IOException {
return performRequest(mainRequest, this::toRequest, this::toResponse, emptySet(), headers);
MainResponse custom(MainRequest mainRequest, RequestOptions options) throws IOException {
return performRequest(mainRequest, this::toRequest, options, this::toResponse, emptySet());
}

MainResponse customAndParse(MainRequest mainRequest, Header... headers) throws IOException {
return performRequestAndParseEntity(mainRequest, this::toRequest, MainResponse::fromXContent, emptySet(), headers);
MainResponse customAndParse(MainRequest mainRequest, RequestOptions options) throws IOException {
return performRequestAndParseEntity(mainRequest, this::toRequest, options, MainResponse::fromXContent, emptySet());
}

void customAsync(MainRequest mainRequest, ActionListener<MainResponse> listener, Header... headers) {
performRequestAsync(mainRequest, this::toRequest, this::toResponse, listener, emptySet(), headers);
void customAsync(MainRequest mainRequest, RequestOptions options, ActionListener<MainResponse> listener) {
performRequestAsync(mainRequest, this::toRequest, options, this::toResponse, listener, emptySet());
}

void customAndParseAsync(MainRequest mainRequest, ActionListener<MainResponse> listener, Header... headers) {
performRequestAsyncAndParseEntity(mainRequest, this::toRequest, MainResponse::fromXContent, listener, emptySet(), headers);
void customAndParseAsync(MainRequest mainRequest, RequestOptions options, ActionListener<MainResponse> listener) {
performRequestAsyncAndParseEntity(mainRequest, this::toRequest, options, MainResponse::fromXContent, listener, emptySet());
}

Request toRequest(MainRequest mainRequest) throws IOException {
Expand Down
Loading