Skip to content

Commit f0f75b1

Browse files
authored
Support Preemptive Authentication with RestClient (elastic#21336)
This adds the necessary `AuthCache` needed to support preemptive authorization. By adding every host to the cache, the automatically added `RequestAuthCache` interceptor will add credentials on the first pass rather than waiting to do it after _each_ anonymous request is rejected (thus always sending everything twice when basic auth is required).
1 parent 47c0e13 commit f0f75b1

File tree

6 files changed

+157
-41
lines changed

6 files changed

+157
-41
lines changed

client/rest/src/main/java/org/elasticsearch/client/RestClient.java

Lines changed: 39 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
import org.apache.http.HttpHost;
2626
import org.apache.http.HttpRequest;
2727
import org.apache.http.HttpResponse;
28+
import org.apache.http.client.AuthCache;
2829
import org.apache.http.client.ClientProtocolException;
2930
import org.apache.http.client.methods.HttpEntityEnclosingRequestBase;
3031
import org.apache.http.client.methods.HttpHead;
@@ -34,8 +35,11 @@
3435
import org.apache.http.client.methods.HttpPut;
3536
import org.apache.http.client.methods.HttpRequestBase;
3637
import org.apache.http.client.methods.HttpTrace;
38+
import org.apache.http.client.protocol.HttpClientContext;
3739
import org.apache.http.client.utils.URIBuilder;
3840
import org.apache.http.concurrent.FutureCallback;
41+
import org.apache.http.impl.auth.BasicScheme;
42+
import org.apache.http.impl.client.BasicAuthCache;
3943
import org.apache.http.impl.nio.client.CloseableHttpAsyncClient;
4044
import org.apache.http.nio.client.methods.HttpAsyncMethods;
4145
import org.apache.http.nio.protocol.HttpAsyncRequestProducer;
@@ -92,7 +96,7 @@ public class RestClient implements Closeable {
9296
private final long maxRetryTimeoutMillis;
9397
private final String pathPrefix;
9498
private final AtomicInteger lastHostIndex = new AtomicInteger(0);
95-
private volatile Set<HttpHost> hosts;
99+
private volatile HostTuple<Set<HttpHost>> hostTuple;
96100
private final ConcurrentMap<HttpHost, DeadHostState> blacklist = new ConcurrentHashMap<>();
97101
private final FailureListener failureListener;
98102

@@ -122,11 +126,13 @@ public synchronized void setHosts(HttpHost... hosts) {
122126
throw new IllegalArgumentException("hosts must not be null nor empty");
123127
}
124128
Set<HttpHost> httpHosts = new HashSet<>();
129+
AuthCache authCache = new BasicAuthCache();
125130
for (HttpHost host : hosts) {
126131
Objects.requireNonNull(host, "host cannot be null");
127132
httpHosts.add(host);
133+
authCache.put(host, new BasicScheme());
128134
}
129-
this.hosts = Collections.unmodifiableSet(httpHosts);
135+
this.hostTuple = new HostTuple<>(Collections.unmodifiableSet(httpHosts), authCache);
130136
this.blacklist.clear();
131137
}
132138

@@ -315,19 +321,22 @@ public void performRequestAsync(String method, String endpoint, Map<String, Stri
315321
setHeaders(request, headers);
316322
FailureTrackingResponseListener failureTrackingResponseListener = new FailureTrackingResponseListener(responseListener);
317323
long startTime = System.nanoTime();
318-
performRequestAsync(startTime, nextHost().iterator(), request, ignoreErrorCodes, httpAsyncResponseConsumerFactory,
319-
failureTrackingResponseListener);
324+
performRequestAsync(startTime, nextHost(), request, ignoreErrorCodes, httpAsyncResponseConsumerFactory,
325+
failureTrackingResponseListener);
320326
}
321327

322-
private void performRequestAsync(final long startTime, final Iterator<HttpHost> hosts, final HttpRequestBase request,
328+
private void performRequestAsync(final long startTime, final HostTuple<Iterator<HttpHost>> hostTuple, final HttpRequestBase request,
323329
final Set<Integer> ignoreErrorCodes,
324330
final HttpAsyncResponseConsumerFactory httpAsyncResponseConsumerFactory,
325331
final FailureTrackingResponseListener listener) {
326-
final HttpHost host = hosts.next();
332+
final HttpHost host = hostTuple.hosts.next();
327333
//we stream the request body if the entity allows for it
328-
HttpAsyncRequestProducer requestProducer = HttpAsyncMethods.create(host, request);
329-
HttpAsyncResponseConsumer<HttpResponse> asyncResponseConsumer = httpAsyncResponseConsumerFactory.createHttpAsyncResponseConsumer();
330-
client.execute(requestProducer, asyncResponseConsumer, new FutureCallback<HttpResponse>() {
334+
final HttpAsyncRequestProducer requestProducer = HttpAsyncMethods.create(host, request);
335+
final HttpAsyncResponseConsumer<HttpResponse> asyncResponseConsumer =
336+
httpAsyncResponseConsumerFactory.createHttpAsyncResponseConsumer();
337+
final HttpClientContext context = HttpClientContext.create();
338+
context.setAuthCache(hostTuple.authCache);
339+
client.execute(requestProducer, asyncResponseConsumer, context, new FutureCallback<HttpResponse>() {
331340
@Override
332341
public void completed(HttpResponse httpResponse) {
333342
try {
@@ -366,7 +375,7 @@ public void failed(Exception failure) {
366375
}
367376

368377
private void retryIfPossible(Exception exception) {
369-
if (hosts.hasNext()) {
378+
if (hostTuple.hosts.hasNext()) {
370379
//in case we are retrying, check whether maxRetryTimeout has been reached
371380
long timeElapsedMillis = TimeUnit.NANOSECONDS.toMillis(System.nanoTime() - startTime);
372381
long timeout = maxRetryTimeoutMillis - timeElapsedMillis;
@@ -377,7 +386,7 @@ private void retryIfPossible(Exception exception) {
377386
} else {
378387
listener.trackFailure(exception);
379388
request.reset();
380-
performRequestAsync(startTime, hosts, request, ignoreErrorCodes, httpAsyncResponseConsumerFactory, listener);
389+
performRequestAsync(startTime, hostTuple, request, ignoreErrorCodes, httpAsyncResponseConsumerFactory, listener);
381390
}
382391
} else {
383392
listener.onDefinitiveFailure(exception);
@@ -415,17 +424,18 @@ private void setHeaders(HttpRequest httpRequest, Header[] requestHeaders) {
415424
* The iterator returned will never be empty. In case there are no healthy hosts available, or dead ones to be be retried,
416425
* one dead host gets returned so that it can be retried.
417426
*/
418-
private Iterable<HttpHost> nextHost() {
427+
private HostTuple<Iterator<HttpHost>> nextHost() {
428+
final HostTuple<Set<HttpHost>> hostTuple = this.hostTuple;
419429
Collection<HttpHost> nextHosts = Collections.emptySet();
420430
do {
421-
Set<HttpHost> filteredHosts = new HashSet<>(hosts);
431+
Set<HttpHost> filteredHosts = new HashSet<>(hostTuple.hosts);
422432
for (Map.Entry<HttpHost, DeadHostState> entry : blacklist.entrySet()) {
423433
if (System.nanoTime() - entry.getValue().getDeadUntilNanos() < 0) {
424434
filteredHosts.remove(entry.getKey());
425435
}
426436
}
427437
if (filteredHosts.isEmpty()) {
428-
//last resort: if there are no good hosts to use, return a single dead one, the one that's closest to being retried
438+
//last resort: if there are no good host to use, return a single dead one, the one that's closest to being retried
429439
List<Map.Entry<HttpHost, DeadHostState>> sortedHosts = new ArrayList<>(blacklist.entrySet());
430440
if (sortedHosts.size() > 0) {
431441
Collections.sort(sortedHosts, new Comparator<Map.Entry<HttpHost, DeadHostState>>() {
@@ -444,7 +454,7 @@ public int compare(Map.Entry<HttpHost, DeadHostState> o1, Map.Entry<HttpHost, De
444454
nextHosts = rotatedHosts;
445455
}
446456
} while(nextHosts.isEmpty());
447-
return nextHosts;
457+
return new HostTuple<>(nextHosts.iterator(), hostTuple.authCache);
448458
}
449459

450460
/**
@@ -686,4 +696,18 @@ public void onFailure(HttpHost host) {
686696

687697
}
688698
}
699+
700+
/**
701+
* {@code HostTuple} enables the {@linkplain HttpHost}s and {@linkplain AuthCache} to be set together in a thread
702+
* safe, volatile way.
703+
*/
704+
private static class HostTuple<T> {
705+
public final T hosts;
706+
public final AuthCache authCache;
707+
708+
public HostTuple(final T hosts, final AuthCache authCache) {
709+
this.hosts = hosts;
710+
this.authCache = authCache;
711+
}
712+
}
689713
}

client/rest/src/test/java/org/elasticsearch/client/RestClientMultipleHostsTests.java

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,8 +26,10 @@
2626
import org.apache.http.ProtocolVersion;
2727
import org.apache.http.StatusLine;
2828
import org.apache.http.client.methods.HttpUriRequest;
29+
import org.apache.http.client.protocol.HttpClientContext;
2930
import org.apache.http.concurrent.FutureCallback;
3031
import org.apache.http.conn.ConnectTimeoutException;
32+
import org.apache.http.impl.auth.BasicScheme;
3133
import org.apache.http.impl.nio.client.CloseableHttpAsyncClient;
3234
import org.apache.http.message.BasicHttpResponse;
3335
import org.apache.http.message.BasicStatusLine;
@@ -73,13 +75,15 @@ public class RestClientMultipleHostsTests extends RestClientTestCase {
7375
public void createRestClient() throws IOException {
7476
CloseableHttpAsyncClient httpClient = mock(CloseableHttpAsyncClient.class);
7577
when(httpClient.<HttpResponse>execute(any(HttpAsyncRequestProducer.class), any(HttpAsyncResponseConsumer.class),
76-
any(FutureCallback.class))).thenAnswer(new Answer<Future<HttpResponse>>() {
78+
any(HttpClientContext.class), any(FutureCallback.class))).thenAnswer(new Answer<Future<HttpResponse>>() {
7779
@Override
7880
public Future<HttpResponse> answer(InvocationOnMock invocationOnMock) throws Throwable {
7981
HttpAsyncRequestProducer requestProducer = (HttpAsyncRequestProducer) invocationOnMock.getArguments()[0];
8082
HttpUriRequest request = (HttpUriRequest)requestProducer.generateRequest();
8183
HttpHost httpHost = requestProducer.getTarget();
82-
FutureCallback<HttpResponse> futureCallback = (FutureCallback<HttpResponse>) invocationOnMock.getArguments()[2];
84+
HttpClientContext context = (HttpClientContext) invocationOnMock.getArguments()[2];
85+
assertThat(context.getAuthCache().get(httpHost), instanceOf(BasicScheme.class));
86+
FutureCallback<HttpResponse> futureCallback = (FutureCallback<HttpResponse>) invocationOnMock.getArguments()[3];
8387
//return the desired status code or exception depending on the path
8488
if (request.getURI().getPath().equals("/soe")) {
8589
futureCallback.failed(new SocketTimeoutException(httpHost.toString()));

client/rest/src/test/java/org/elasticsearch/client/RestClientSingleHostIntegTests.java

Lines changed: 77 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,11 @@
2626
import org.apache.http.Consts;
2727
import org.apache.http.Header;
2828
import org.apache.http.HttpHost;
29+
import org.apache.http.auth.AuthScope;
30+
import org.apache.http.auth.UsernamePasswordCredentials;
2931
import org.apache.http.entity.StringEntity;
32+
import org.apache.http.impl.client.BasicCredentialsProvider;
33+
import org.apache.http.impl.nio.client.HttpAsyncClientBuilder;
3034
import org.apache.http.util.EntityUtils;
3135
import org.codehaus.mojo.animal_sniffer.IgnoreJRERequirement;
3236
import org.elasticsearch.mocksocket.MockHttpServer;
@@ -48,7 +52,10 @@
4852
import static org.elasticsearch.client.RestClientTestUtil.getAllStatusCodes;
4953
import static org.elasticsearch.client.RestClientTestUtil.getHttpMethods;
5054
import static org.elasticsearch.client.RestClientTestUtil.randomStatusCode;
55+
import static org.hamcrest.Matchers.nullValue;
56+
import static org.hamcrest.Matchers.startsWith;
5157
import static org.junit.Assert.assertEquals;
58+
import static org.junit.Assert.assertThat;
5259
import static org.junit.Assert.assertTrue;
5360

5461
/**
@@ -66,22 +73,10 @@ public class RestClientSingleHostIntegTests extends RestClientTestCase {
6673

6774
@BeforeClass
6875
public static void startHttpServer() throws Exception {
69-
String pathPrefixWithoutLeadingSlash;
70-
if (randomBoolean()) {
71-
pathPrefixWithoutLeadingSlash = "testPathPrefix/" + randomAsciiOfLengthBetween(1, 5);
72-
pathPrefix = "/" + pathPrefixWithoutLeadingSlash;
73-
} else {
74-
pathPrefix = pathPrefixWithoutLeadingSlash = "";
75-
}
76-
76+
pathPrefix = randomBoolean() ? "/testPathPrefix/" + randomAsciiOfLengthBetween(1, 5) : "";
7777
httpServer = createHttpServer();
7878
defaultHeaders = RestClientTestUtil.randomHeaders(getRandom(), "Header-default");
79-
RestClientBuilder restClientBuilder = RestClient.builder(
80-
new HttpHost(httpServer.getAddress().getHostString(), httpServer.getAddress().getPort())).setDefaultHeaders(defaultHeaders);
81-
if (pathPrefix.length() > 0) {
82-
restClientBuilder.setPathPrefix((randomBoolean() ? "/" : "") + pathPrefixWithoutLeadingSlash);
83-
}
84-
restClient = restClientBuilder.build();
79+
restClient = createRestClient(false, true);
8580
}
8681

8782
private static HttpServer createHttpServer() throws Exception {
@@ -129,6 +124,35 @@ public void handle(HttpExchange httpExchange) throws IOException {
129124
}
130125
}
131126

127+
private static RestClient createRestClient(final boolean useAuth, final boolean usePreemptiveAuth) {
128+
// provide the username/password for every request
129+
final BasicCredentialsProvider credentialsProvider = new BasicCredentialsProvider();
130+
credentialsProvider.setCredentials(AuthScope.ANY, new UsernamePasswordCredentials("user", "pass"));
131+
132+
final RestClientBuilder restClientBuilder = RestClient.builder(
133+
new HttpHost(httpServer.getAddress().getHostString(), httpServer.getAddress().getPort())).setDefaultHeaders(defaultHeaders);
134+
if (pathPrefix.length() > 0) {
135+
// sometimes cut off the leading slash
136+
restClientBuilder.setPathPrefix(randomBoolean() ? pathPrefix.substring(1) : pathPrefix);
137+
}
138+
139+
if (useAuth) {
140+
restClientBuilder.setHttpClientConfigCallback(new RestClientBuilder.HttpClientConfigCallback() {
141+
@Override
142+
public HttpAsyncClientBuilder customizeHttpClient(final HttpAsyncClientBuilder httpClientBuilder) {
143+
if (usePreemptiveAuth == false) {
144+
// disable preemptive auth by ignoring any authcache
145+
httpClientBuilder.disableAuthCaching();
146+
}
147+
148+
return httpClientBuilder.setDefaultCredentialsProvider(credentialsProvider);
149+
}
150+
});
151+
}
152+
153+
return restClientBuilder.build();
154+
}
155+
132156
@AfterClass
133157
public static void stopHttpServers() throws IOException {
134158
restClient.close();
@@ -159,7 +183,7 @@ public void testHeaders() throws IOException {
159183

160184
assertEquals(method, esResponse.getRequestLine().getMethod());
161185
assertEquals(statusCode, esResponse.getStatusLine().getStatusCode());
162-
assertEquals((pathPrefix.length() > 0 ? pathPrefix : "") + "/" + statusCode, esResponse.getRequestLine().getUri());
186+
assertEquals(pathPrefix + "/" + statusCode, esResponse.getRequestLine().getUri());
163187
assertHeaders(defaultHeaders, requestHeaders, esResponse.getHeaders(), standardHeaders);
164188
for (final Header responseHeader : esResponse.getHeaders()) {
165189
String name = responseHeader.getName();
@@ -189,7 +213,41 @@ public void testGetWithBody() throws IOException {
189213
bodyTest("GET");
190214
}
191215

192-
private void bodyTest(String method) throws IOException {
216+
/**
217+
* Verify that credentials are sent on the first request with preemptive auth enabled (default when provided with credentials).
218+
*/
219+
public void testPreemptiveAuthEnabled() throws IOException {
220+
final String[] methods = { "POST", "PUT", "GET", "DELETE" };
221+
222+
try (final RestClient restClient = createRestClient(true, true)) {
223+
for (final String method : methods) {
224+
final Response response = bodyTest(restClient, method);
225+
226+
assertThat(response.getHeader("Authorization"), startsWith("Basic"));
227+
}
228+
}
229+
}
230+
231+
/**
232+
* Verify that credentials are <em>not</em> sent on the first request with preemptive auth disabled.
233+
*/
234+
public void testPreemptiveAuthDisabled() throws IOException {
235+
final String[] methods = { "POST", "PUT", "GET", "DELETE" };
236+
237+
try (final RestClient restClient = createRestClient(true, false)) {
238+
for (final String method : methods) {
239+
final Response response = bodyTest(restClient, method);
240+
241+
assertThat(response.getHeader("Authorization"), nullValue());
242+
}
243+
}
244+
}
245+
246+
private Response bodyTest(final String method) throws IOException {
247+
return bodyTest(restClient, method);
248+
}
249+
250+
private Response bodyTest(final RestClient restClient, final String method) throws IOException {
193251
String requestBody = "{ \"field\": \"value\" }";
194252
StringEntity entity = new StringEntity(requestBody);
195253
int statusCode = randomStatusCode(getRandom());
@@ -201,7 +259,9 @@ private void bodyTest(String method) throws IOException {
201259
}
202260
assertEquals(method, esResponse.getRequestLine().getMethod());
203261
assertEquals(statusCode, esResponse.getStatusLine().getStatusCode());
204-
assertEquals((pathPrefix.length() > 0 ? pathPrefix : "") + "/" + statusCode, esResponse.getRequestLine().getUri());
262+
assertEquals(pathPrefix + "/" + statusCode, esResponse.getRequestLine().getUri());
205263
assertEquals(requestBody, EntityUtils.toString(esResponse.getEntity()));
264+
265+
return esResponse;
206266
}
207267
}

client/rest/src/test/java/org/elasticsearch/client/RestClientSingleHostTests.java

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -34,10 +34,12 @@
3434
import org.apache.http.client.methods.HttpPut;
3535
import org.apache.http.client.methods.HttpTrace;
3636
import org.apache.http.client.methods.HttpUriRequest;
37+
import org.apache.http.client.protocol.HttpClientContext;
3738
import org.apache.http.client.utils.URIBuilder;
3839
import org.apache.http.concurrent.FutureCallback;
3940
import org.apache.http.conn.ConnectTimeoutException;
4041
import org.apache.http.entity.StringEntity;
42+
import org.apache.http.impl.auth.BasicScheme;
4143
import org.apache.http.impl.nio.client.CloseableHttpAsyncClient;
4244
import org.apache.http.message.BasicHttpResponse;
4345
import org.apache.http.message.BasicStatusLine;
@@ -96,11 +98,13 @@ public class RestClientSingleHostTests extends RestClientTestCase {
9698
public void createRestClient() throws IOException {
9799
httpClient = mock(CloseableHttpAsyncClient.class);
98100
when(httpClient.<HttpResponse>execute(any(HttpAsyncRequestProducer.class), any(HttpAsyncResponseConsumer.class),
99-
any(FutureCallback.class))).thenAnswer(new Answer<Future<HttpResponse>>() {
101+
any(HttpClientContext.class), any(FutureCallback.class))).thenAnswer(new Answer<Future<HttpResponse>>() {
100102
@Override
101103
public Future<HttpResponse> answer(InvocationOnMock invocationOnMock) throws Throwable {
102104
HttpAsyncRequestProducer requestProducer = (HttpAsyncRequestProducer) invocationOnMock.getArguments()[0];
103-
FutureCallback<HttpResponse> futureCallback = (FutureCallback<HttpResponse>) invocationOnMock.getArguments()[2];
105+
HttpClientContext context = (HttpClientContext) invocationOnMock.getArguments()[2];
106+
assertThat(context.getAuthCache().get(httpHost), instanceOf(BasicScheme.class));
107+
FutureCallback<HttpResponse> futureCallback = (FutureCallback<HttpResponse>) invocationOnMock.getArguments()[3];
104108
HttpUriRequest request = (HttpUriRequest)requestProducer.generateRequest();
105109
//return the desired status code or exception depending on the path
106110
if (request.getURI().getPath().equals("/soe")) {
@@ -156,7 +160,7 @@ public void testInternalHttpRequest() throws Exception {
156160
for (String httpMethod : getHttpMethods()) {
157161
HttpUriRequest expectedRequest = performRandomRequest(httpMethod);
158162
verify(httpClient, times(++times)).<HttpResponse>execute(requestArgumentCaptor.capture(),
159-
any(HttpAsyncResponseConsumer.class), any(FutureCallback.class));
163+
any(HttpAsyncResponseConsumer.class), any(HttpClientContext.class), any(FutureCallback.class));
160164
HttpUriRequest actualRequest = (HttpUriRequest)requestArgumentCaptor.getValue().generateRequest();
161165
assertEquals(expectedRequest.getURI(), actualRequest.getURI());
162166
assertEquals(expectedRequest.getClass(), actualRequest.getClass());

0 commit comments

Comments
 (0)