diff --git a/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/rest/SecurityRestFilter.java b/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/rest/SecurityRestFilter.java index 40473bf6fe8ae..9ca77882aa567 100644 --- a/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/rest/SecurityRestFilter.java +++ b/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/rest/SecurityRestFilter.java @@ -46,6 +46,17 @@ public class SecurityRestFilter implements RestHandler { private final ThreadContext threadContext; private final boolean extractClientCertificate; + public enum ActionType { + Authentication("Authentication"), + SecondaryAuthentication("Secondary authentication"), + RequestHandling("Request handling"); + + private final String name; + ActionType(String name) { this.name = name; } + @Override + public String toString() { return name; } + } + public SecurityRestFilter(Settings settings, ThreadContext threadContext, AuthenticationService authenticationService, SecondaryAuthenticator secondaryAuthenticator, RestHandler restHandler, boolean extractClientCertificate) { this.settings = settings; @@ -89,16 +100,20 @@ public void handleRequest(RestRequest request, RestChannel channel, NodeClient c logger.trace("Found secondary authentication {} in REST request [{}]", secondaryAuthentication, requestUri); } RemoteHostHeader.process(request, threadContext); - restHandler.handleRequest(request, channel, client); + try { + restHandler.handleRequest(request, channel, client); + } catch (Exception e) { + handleException(ActionType.RequestHandling, request, channel, e); + } }, - e -> handleException("Secondary authentication", request, channel, e))); - }, e -> handleException("Authentication", request, channel, e))); + e -> handleException(ActionType.SecondaryAuthentication, request, channel, e))); + }, e -> handleException(ActionType.Authentication, request, channel, e))); } else { restHandler.handleRequest(request, channel, client); } } - private void handleException(String actionType, RestRequest request, RestChannel channel, Exception e) { + protected void handleException(ActionType actionType, RestRequest request, RestChannel channel, Exception e) { logger.debug(new ParameterizedMessage("{} failed for REST request [{}]", actionType, request.uri()), e); final RestStatus restStatus = ExceptionsHelper.status(e); try { @@ -109,11 +124,14 @@ private void handleException(String actionType, RestRequest request, RestChannel @Override public Map> filterHeaders(Map> headers) { - if (headers.containsKey("Warning")) { - headers = Maps.copyMapWithRemovedEntry(headers, "Warning"); - } - if (headers.containsKey("X-elastic-product")) { - headers = Maps.copyMapWithRemovedEntry(headers, "X-elastic-product"); + if (actionType != ActionType.RequestHandling + || (restStatus == RestStatus.UNAUTHORIZED || restStatus == RestStatus.FORBIDDEN)) { + if (headers.containsKey("Warning")) { + headers = Maps.copyMapWithRemovedEntry(headers, "Warning"); + } + if (headers.containsKey("X-elastic-product")) { + headers = Maps.copyMapWithRemovedEntry(headers, "X-elastic-product"); + } } return headers; } diff --git a/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/rest/SecurityRestFilterWarningHeadersTests.java b/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/rest/SecurityRestFilterWarningHeadersTests.java new file mode 100644 index 0000000000000..467ae9d8eca90 --- /dev/null +++ b/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/rest/SecurityRestFilterWarningHeadersTests.java @@ -0,0 +1,143 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.security.rest; + +import org.elasticsearch.ElasticsearchStatusException; +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.common.collect.MapBuilder; +import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.common.util.concurrent.ThreadContext; +import org.elasticsearch.common.xcontent.NamedXContentRegistry; +import org.elasticsearch.common.xcontent.json.JsonXContent; +import org.elasticsearch.rest.BytesRestResponse; +import org.elasticsearch.rest.RestChannel; +import org.elasticsearch.rest.RestHandler; +import org.elasticsearch.rest.RestRequest; +import org.elasticsearch.rest.RestResponse; +import org.elasticsearch.rest.RestStatus; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.test.rest.FakeRestRequest; +import org.elasticsearch.xpack.core.security.authc.Authentication; +import org.elasticsearch.xpack.security.authc.AuthenticationService; +import org.elasticsearch.xpack.security.authc.support.SecondaryAuthenticator; +import org.junit.Before; +import org.mockito.ArgumentCaptor; + +import java.util.Collections; +import java.util.List; +import java.util.Map; + +import static org.elasticsearch.mock.orig.Mockito.doThrow; +import static org.elasticsearch.test.ActionListenerUtils.anyActionListener; +import static org.mockito.Matchers.eq; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +public class SecurityRestFilterWarningHeadersTests extends ESTestCase { + private ThreadContext threadContext; + private AuthenticationService authcService; + private SecondaryAuthenticator secondaryAuthenticator; + private RestHandler restHandler; + + @Override + protected boolean enableWarningsCheck() { + return false; + } + + @Before + public void init() throws Exception { + authcService = mock(AuthenticationService.class); + restHandler = mock(RestHandler.class); + threadContext = new ThreadContext(Settings.EMPTY); + secondaryAuthenticator = new SecondaryAuthenticator(Settings.EMPTY, threadContext, authcService); + } + + public void testResponseHeadersOnFailure() throws Exception { + MapBuilder> headers = new MapBuilder<>(); + headers.put("Warning", Collections.singletonList("Some warning header")); + headers.put("X-elastic-product", Collections.singletonList("Some product header")); + Map> afterHeaders; + + // Remove all the headers on authentication failures + afterHeaders = testProcessAuthenticationFailed(RestStatus.BAD_REQUEST, headers); + assertEquals(afterHeaders.size(), 0); + afterHeaders = testProcessAuthenticationFailed(RestStatus.INTERNAL_SERVER_ERROR, headers); + assertEquals(afterHeaders.size(), 0); + afterHeaders = testProcessAuthenticationFailed(RestStatus.UNAUTHORIZED, headers); + assertEquals(afterHeaders.size(), 0); + afterHeaders = testProcessAuthenticationFailed(RestStatus.FORBIDDEN, headers); + assertEquals(afterHeaders.size(), 0); + + // On rest handling failures only remove headers if rest status is UNAUTHORIZED or FORBIDDEN + afterHeaders = testProcessRestHandlingFailed(RestStatus.BAD_REQUEST, headers); + assertEquals(afterHeaders.size(), 2); + afterHeaders = testProcessRestHandlingFailed(RestStatus.INTERNAL_SERVER_ERROR, headers); + assertEquals(afterHeaders.size(), 2); + afterHeaders = testProcessRestHandlingFailed(RestStatus.UNAUTHORIZED, headers); + assertEquals(afterHeaders.size(), 0); + afterHeaders = testProcessRestHandlingFailed(RestStatus.FORBIDDEN, headers); + assertEquals(afterHeaders.size(), 0); + } + + private Map> testProcessRestHandlingFailed(RestStatus restStatus, MapBuilder> headers) + throws Exception { + RestChannel channel = mock(RestChannel.class); + SecurityRestFilter filter = new SecurityRestFilter(Settings.EMPTY, threadContext, authcService, secondaryAuthenticator, + restHandler, false); + RestRequest request = new FakeRestRequest.Builder(NamedXContentRegistry.EMPTY).build(); + Authentication primaryAuthentication = mock(Authentication.class); + when(primaryAuthentication.encode()).thenReturn(randomAlphaOfLengthBetween(12, 36)); + doAnswer(i -> { + final Object[] arguments = i.getArguments(); + @SuppressWarnings("unchecked") + ActionListener callback = (ActionListener) arguments[arguments.length - 1]; + callback.onResponse(primaryAuthentication); + return null; + }).when(authcService).authenticate(eq(request), anyActionListener()); + Authentication secondaryAuthentication = mock(Authentication.class); + when(secondaryAuthentication.encode()).thenReturn(randomAlphaOfLengthBetween(12, 36)); + doAnswer(i -> { + final Object[] arguments = i.getArguments(); + @SuppressWarnings("unchecked") + ActionListener callback = (ActionListener) arguments[arguments.length - 1]; + callback.onResponse(secondaryAuthentication); + return null; + }).when(authcService).authenticate(eq(request), eq(false), anyActionListener()); + doThrow(new ElasticsearchStatusException("Rest handling failed", restStatus, "")) + .when(restHandler).handleRequest(request, channel, null); + when(channel.request()).thenReturn(request); + when(channel.newErrorBuilder()).thenReturn(JsonXContent.contentBuilder()); + filter.handleRequest(request, channel, null); + ArgumentCaptor response = ArgumentCaptor.forClass(BytesRestResponse.class); + verify(channel).sendResponse(response.capture()); + RestResponse restResponse = response.getValue(); + return restResponse.filterHeaders(headers.immutableMap()); + } + + private Map> testProcessAuthenticationFailed(RestStatus restStatus, MapBuilder> headers) + throws Exception { + RestChannel channel = mock(RestChannel.class); + SecurityRestFilter filter = new SecurityRestFilter(Settings.EMPTY, threadContext, authcService, secondaryAuthenticator, + restHandler, false); + RestRequest request = new FakeRestRequest.Builder(NamedXContentRegistry.EMPTY).build(); + doAnswer((i) -> { + ActionListener callback = (ActionListener) i.getArguments()[1]; + callback.onFailure(new ElasticsearchStatusException("Authentication failed", restStatus, "")); + return Void.TYPE; + }).when(authcService).authenticate(eq(request), anyActionListener()); + when(channel.request()).thenReturn(request); + when(channel.newErrorBuilder()).thenReturn(JsonXContent.contentBuilder()); + filter.handleRequest(request, channel, null); + ArgumentCaptor response = ArgumentCaptor.forClass(BytesRestResponse.class); + verify(channel).sendResponse(response.capture()); + RestResponse restResponse = response.getValue(); + return restResponse.filterHeaders(headers.immutableMap()); + } +}