diff --git a/core/src/main/java/org/springframework/security/access/expression/method/DefaultMethodSecurityExpressionHandler.java b/core/src/main/java/org/springframework/security/access/expression/method/DefaultMethodSecurityExpressionHandler.java index 6254ee087d7..108f89d33b5 100644 --- a/core/src/main/java/org/springframework/security/access/expression/method/DefaultMethodSecurityExpressionHandler.java +++ b/core/src/main/java/org/springframework/security/access/expression/method/DefaultMethodSecurityExpressionHandler.java @@ -23,6 +23,7 @@ import java.util.LinkedHashMap; import java.util.List; import java.util.Map; +import java.util.Optional; import java.util.stream.Stream; import org.aopalliance.intercept.MethodInvocation; @@ -94,7 +95,7 @@ protected MethodSecurityExpressionOperations createSecurityExpressionRoot(Authen /** * Filters the {@code filterTarget} object (which must be either a collection, array, - * map or stream), by evaluating the supplied expression. + * map, stream or optional), by evaluating the supplied expression. *

* If a {@code Collection} or {@code Map} is used, the original instance will be * modified to contain the elements for which the permission expression evaluates to @@ -117,11 +118,14 @@ public Object filter(Object filterTarget, Expression filterExpression, Evaluatio if (filterTarget instanceof Stream) { return filterStream((Stream) filterTarget, filterExpression, ctx, rootObject); } + if (filterTarget instanceof Optional) { + return filterOptional((Optional) filterTarget, filterExpression, ctx, rootObject); + } throw new IllegalArgumentException( - "Filter target must be a collection, array, map or stream type, but was " + filterTarget); + "Filter target must be a collection, array, map, stream or optional type, but was " + filterTarget); } - private Object filterCollection(Collection filterTarget, Expression filterExpression, EvaluationContext ctx, + private Collection filterCollection(final Collection filterTarget, Expression filterExpression, EvaluationContext ctx, MethodSecurityExpressionOperations rootObject) { this.logger.debug(LogMessage.format("Filtering collection with %s elements", filterTarget.size())); List retain = new ArrayList<>(filterTarget.size()); @@ -140,7 +144,7 @@ private Object filterCollection(Collection filterTarget, Expression filte return filterTarget; } - private Object filterArray(Object[] filterTarget, Expression filterExpression, EvaluationContext ctx, + private Object[] filterArray(final Object[] filterTarget, Expression filterExpression, EvaluationContext ctx, MethodSecurityExpressionOperations rootObject) { List retain = new ArrayList<>(filterTarget.length); this.logger.debug(LogMessage.format("Filtering array with %s elements", filterTarget.length)); @@ -162,7 +166,7 @@ private Object filterArray(Object[] filterTarget, Expression filterExpression, E return filtered; } - private Object filterMap(final Map filterTarget, Expression filterExpression, EvaluationContext ctx, + private Map filterMap(final Map filterTarget, Expression filterExpression, EvaluationContext ctx, MethodSecurityExpressionOperations rootObject) { Map retain = new LinkedHashMap<>(filterTarget.size()); this.logger.debug(LogMessage.format("Filtering map with %s elements", filterTarget.size())); @@ -178,14 +182,22 @@ private Object filterMap(final Map filterTarget, Expression filterE return filterTarget; } - private Object filterStream(final Stream filterTarget, Expression filterExpression, EvaluationContext ctx, + private Stream filterStream(final Stream filterTarget, Expression filterExpression, EvaluationContext ctx, MethodSecurityExpressionOperations rootObject) { - return filterTarget.filter((filterObject) -> { + return filterTarget.filter(filterObject -> { rootObject.setFilterObject(filterObject); return ExpressionUtils.evaluateAsBoolean(filterExpression, ctx); }).onClose(filterTarget::close); } + private Optional filterOptional(final Optional filterTarget, Expression filterExpression, + EvaluationContext ctx, MethodSecurityExpressionOperations rootObject) { + return filterTarget.filter(filterObject -> { + rootObject.setFilterObject(filterObject); + return ExpressionUtils.evaluateAsBoolean(filterExpression, ctx); + }); + } + /** * Sets the {@link AuthenticationTrustResolver} to be used. The default is * {@link AuthenticationTrustResolverImpl}. diff --git a/core/src/test/java/org/springframework/security/access/expression/method/DefaultMethodSecurityExpressionHandlerTests.java b/core/src/test/java/org/springframework/security/access/expression/method/DefaultMethodSecurityExpressionHandlerTests.java index ce5d2b015e6..66d6c7cef05 100644 --- a/core/src/test/java/org/springframework/security/access/expression/method/DefaultMethodSecurityExpressionHandlerTests.java +++ b/core/src/test/java/org/springframework/security/access/expression/method/DefaultMethodSecurityExpressionHandlerTests.java @@ -19,6 +19,7 @@ import java.util.HashMap; import java.util.List; import java.util.Map; +import java.util.Optional; import java.util.stream.Collectors; import java.util.stream.Stream; @@ -45,7 +46,7 @@ import static org.mockito.Mockito.verify; @ExtendWith(MockitoExtension.class) -public class DefaultMethodSecurityExpressionHandlerTests { +class DefaultMethodSecurityExpressionHandlerTests { private DefaultMethodSecurityExpressionHandler handler; @@ -74,12 +75,12 @@ public void cleanup() { } @Test - public void setTrustResolverNull() { + void setTrustResolverNull() { assertThatIllegalArgumentException().isThrownBy(() -> this.handler.setTrustResolver(null)); } @Test - public void createEvaluationContextCustomTrustResolver() { + void createEvaluationContextCustomTrustResolver() { setupMocks(); this.handler.setTrustResolver(this.trustResolver); Expression expression = this.handler.getExpressionParser().parseExpression("anonymous"); @@ -90,7 +91,7 @@ public void createEvaluationContextCustomTrustResolver() { @Test @SuppressWarnings("unchecked") - public void filterByKeyWhenUsingMapThenFiltersMap() { + void filterByKeyWhenUsingMapThenFiltersMap() { setupMocks(); final Map map = new HashMap<>(); map.put("key1", "value1"); @@ -99,16 +100,16 @@ public void filterByKeyWhenUsingMapThenFiltersMap() { Expression expression = this.handler.getExpressionParser().parseExpression("filterObject.key eq 'key2'"); EvaluationContext context = this.handler.createEvaluationContext(this.authentication, this.methodInvocation); Object filtered = this.handler.filter(map, expression, context); - assertThat(filtered == map); + assertThat(filtered).isSameAs(map); Map result = ((Map) filtered); - assertThat(result.size() == 1); - assertThat(result).containsKey("key2"); - assertThat(result).containsValue("value2"); + assertThat(result).hasSize(1) + .containsOnlyKeys("key2") + .containsValue("value2"); } @Test @SuppressWarnings("unchecked") - public void filterByValueWhenUsingMapThenFiltersMap() { + void filterByValueWhenUsingMapThenFiltersMap() { setupMocks(); final Map map = new HashMap<>(); map.put("key1", "value1"); @@ -117,16 +118,16 @@ public void filterByValueWhenUsingMapThenFiltersMap() { Expression expression = this.handler.getExpressionParser().parseExpression("filterObject.value eq 'value3'"); EvaluationContext context = this.handler.createEvaluationContext(this.authentication, this.methodInvocation); Object filtered = this.handler.filter(map, expression, context); - assertThat(filtered == map); + assertThat(filtered).isSameAs(map); Map result = ((Map) filtered); - assertThat(result.size() == 1); - assertThat(result).containsKey("key3"); - assertThat(result).containsValue("value3"); + assertThat(result).hasSize(1) + .containsOnlyKeys("key3") + .containsValue("value3"); } @Test @SuppressWarnings("unchecked") - public void filterByKeyAndValueWhenUsingMapThenFiltersMap() { + void filterByKeyAndValueWhenUsingMapThenFiltersMap() { setupMocks(); final Map map = new HashMap<>(); map.put("key1", "value1"); @@ -136,16 +137,16 @@ public void filterByKeyAndValueWhenUsingMapThenFiltersMap() { .parseExpression("(filterObject.key eq 'key1') or (filterObject.value eq 'value2')"); EvaluationContext context = this.handler.createEvaluationContext(this.authentication, this.methodInvocation); Object filtered = this.handler.filter(map, expression, context); - assertThat(filtered == map); + assertThat(filtered).isSameAs(map); Map result = ((Map) filtered); - assertThat(result.size() == 2); - assertThat(result).containsKeys("key1", "key2"); - assertThat(result).containsValues("value1", "value2"); + assertThat(result).hasSize(2) + .containsOnlyKeys("key1", "key2") + .containsValues("value1", "value2"); } @Test @SuppressWarnings("unchecked") - public void filterWhenUsingStreamThenFiltersStream() { + void filterWhenUsingStreamThenFiltersStream() { setupMocks(); final Stream stream = Stream.of("1", "2", "3"); Expression expression = this.handler.getExpressionParser().parseExpression("filterObject ne '2'"); @@ -157,16 +158,49 @@ public void filterWhenUsingStreamThenFiltersStream() { } @Test - public void filterStreamWhenClosedThenUpstreamGetsClosed() { + void filterStreamWhenClosedThenUpstreamGetsClosed() { setupMocks(); final Stream upstream = mock(Stream.class); doReturn(Stream.empty()).when(upstream).filter(any()); Expression expression = this.handler.getExpressionParser().parseExpression("true"); EvaluationContext context = this.handler.createEvaluationContext(this.authentication, this.methodInvocation); - ((Stream) this.handler.filter(upstream, expression, context)).close(); + ((Stream) this.handler.filter(upstream, expression, context)).close(); verify(upstream).close(); } + @Test + @SuppressWarnings("unchecked") + void filterMatchingOptional() { + final Optional optional = Optional.of("1"); + Expression expression = this.handler.getExpressionParser().parseExpression("filterObject ne '2'"); + EvaluationContext context = this.handler.createEvaluationContext(this.authentication, this.methodInvocation); + Object filtered = this.handler.filter(optional, expression, context); + Optional result = ((Optional) filtered); + assertThat(result).isPresent().get().isEqualTo("1"); + } + + @Test + @SuppressWarnings("unchecked") + void filterNotMatchingOptional() { + final Optional optional = Optional.of("2"); + Expression expression = this.handler.getExpressionParser().parseExpression("filterObject ne '2'"); + EvaluationContext context = this.handler.createEvaluationContext(this.authentication, this.methodInvocation); + Object filtered = this.handler.filter(optional, expression, context); + Optional result = ((Optional) filtered); + assertThat(result).isNotPresent(); + } + + @Test + @SuppressWarnings("unchecked") + void filterEmptyOptional() { + final Optional optional = Optional.empty(); + Expression expression = this.handler.getExpressionParser().parseExpression("filterObject ne '2'"); + EvaluationContext context = this.handler.createEvaluationContext(this.authentication, this.methodInvocation); + Object filtered = this.handler.filter(optional, expression, context); + Optional result = ((Optional) filtered); + assertThat(result).isNotPresent(); + } + static class Foo { void bar() {