Skip to content

Add OpenSaml custom types to Saml2AuthenticatedPrincipal #10809

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -648,7 +648,7 @@ private static Object getXmlObjectValue(XMLObject xmlObject) {
if (xmlObject instanceof XSDateTime) {
return ((XSDateTime) xmlObject).getValue();
}
return null;
return xmlObject;
}

private static Saml2AuthenticationException createAuthenticationException(String code, String message,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -250,6 +250,32 @@ public void authenticateWhenAssertionContainsAttributesThenItSucceeds() {
assertThat(principal.getSessionIndexes()).contains("session-index");
}

@Test
public void authenticateWhenAssertionContainsCustomAttributesThenItSucceeds() {
XMLObjectProviderRegistrySupport.getMarshallerFactory().registerMarshaller(
TestCustomOpenSamlObject.CustomSamlObject.TYPE_NAME,
new TestCustomOpenSamlObject.CustomSamlObjectMarshaller());
XMLObjectProviderRegistrySupport.getUnmarshallerFactory().registerUnmarshaller(
TestCustomOpenSamlObject.CustomSamlObject.TYPE_NAME,
new TestCustomOpenSamlObject.CustomSamlObjectUnmarshaller());
Response response = response();
Assertion assertion = assertion();
List<AttributeStatement> attributes = TestOpenSamlObjects.customAttributeStatements();
assertion.getAttributeStatements().addAll(attributes);
TestOpenSamlObjects.signed(assertion, TestSaml2X509Credentials.assertingPartySigningCredential(),
RELYING_PARTY_ENTITY_ID);
response.getAssertions().add(assertion);
Saml2AuthenticationToken token = token(response, verifying(registration()));
Authentication authentication = this.provider.authenticate(token);
Saml2AuthenticatedPrincipal principal = (Saml2AuthenticatedPrincipal) authentication.getPrincipal();
TestCustomOpenSamlObject.CustomSamlObject customSamlObject;
customSamlObject = (TestCustomOpenSamlObject.CustomSamlObject) principal.getAttribute("Address").get(0);
assertThat(customSamlObject.getStreet()).isEqualTo("Test Street");
assertThat(customSamlObject.getStreetNumber()).isEqualTo("1");
assertThat(customSamlObject.getZIP()).isEqualTo("11111");
assertThat(customSamlObject.getCity()).isEqualTo("Test City");
}

@Test
public void authenticateWhenEncryptedAssertionWithoutSignatureThenItFails() {
Response response = response();
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,177 @@
/*
* Copyright 2002-2022 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.springframework.security.saml2.provider.service.authentication;

import java.util.Collections;
import java.util.List;

import javax.annotation.Nonnull;
import javax.annotation.Nullable;
import javax.xml.namespace.QName;

import net.shibboleth.utilities.java.support.xml.ElementSupport;
import org.opensaml.core.xml.AbstractXMLObject;
import org.opensaml.core.xml.AbstractXMLObjectBuilder;
import org.opensaml.core.xml.ElementExtensibleXMLObject;
import org.opensaml.core.xml.Namespace;
import org.opensaml.core.xml.XMLObject;
import org.opensaml.core.xml.io.AbstractXMLObjectMarshaller;
import org.opensaml.core.xml.io.AbstractXMLObjectUnmarshaller;
import org.opensaml.core.xml.io.UnmarshallingException;
import org.opensaml.core.xml.schema.XSAny;
import org.opensaml.core.xml.util.IndexedXMLObjectChildrenList;
import org.opensaml.saml.common.xml.SAMLConstants;
import org.opensaml.saml.saml2.core.AttributeValue;
import org.w3c.dom.Element;

public class TestCustomOpenSamlObject {

public interface CustomSamlObject extends ElementExtensibleXMLObject {

String TYPE_LOCAL_NAME = "CustomType";

String TYPE_CUSTOM_PREFIX = "custom";

String CUSTOM_NS = "https://custom.com/schema/custom";

/** QName of the CustomType type. */
QName TYPE_NAME = new QName(CUSTOM_NS, TYPE_LOCAL_NAME, TYPE_CUSTOM_PREFIX);

String getStreet();

String getStreetNumber();

String getZIP();

String getCity();

}

public static class CustomSamlObjectImpl extends AbstractXMLObject
implements TestCustomOpenSamlObject.CustomSamlObject {

@Nonnull
private IndexedXMLObjectChildrenList<XMLObject> unknownXMLObjects;

/**
* Constructor.
* @param namespaceURI the namespace the element is in
* @param elementLocalName the local name of the XML element this Object
* represents
* @param namespacePrefix the prefix for the given namespace
*/
protected CustomSamlObjectImpl(@Nullable String namespaceURI, @Nonnull String elementLocalName,
@Nullable String namespacePrefix) {
super(namespaceURI, elementLocalName, namespacePrefix);
super.getNamespaceManager().registerNamespaceDeclaration(new Namespace(CUSTOM_NS, TYPE_CUSTOM_PREFIX));
this.unknownXMLObjects = new IndexedXMLObjectChildrenList<>(this);
}

@Nonnull
@Override
public List<XMLObject> getUnknownXMLObjects() {
return this.unknownXMLObjects;
}

@Nonnull
@Override
public List<XMLObject> getUnknownXMLObjects(@Nonnull QName typeOrName) {
return (List<XMLObject>) this.unknownXMLObjects.subList(typeOrName);
}

@Nullable
@Override
public List<XMLObject> getOrderedChildren() {
return Collections.unmodifiableList(this.unknownXMLObjects);
}

@Override
public String getStreet() {
return ((XSAny) getOrderedChildren().get(0)).getTextContent();
}

@Override
public String getStreetNumber() {
return ((XSAny) getOrderedChildren().get(1)).getTextContent();
}

@Override
public String getZIP() {
return ((XSAny) getOrderedChildren().get(2)).getTextContent();
}

@Override
public String getCity() {
return ((XSAny) getOrderedChildren().get(3)).getTextContent();
}

}

public static class CustomSamlObjectBuilder
extends AbstractXMLObjectBuilder<TestCustomOpenSamlObject.CustomSamlObject> {

@Nonnull
@Override
public TestCustomOpenSamlObject.CustomSamlObject buildObject(@Nullable String namespaceURI,
@Nonnull String localName, @Nullable String namespacePrefix) {
return new TestCustomOpenSamlObject.CustomSamlObjectImpl(namespaceURI, localName, namespacePrefix);
}

}

public static class CustomSamlObjectMarshaller extends AbstractXMLObjectMarshaller {

public CustomSamlObjectMarshaller() {
super();
}

@Override
protected void marshallElementContent(@Nonnull XMLObject xmlObject, @Nonnull Element domElement) {
final TestCustomOpenSamlObject.CustomSamlObject customSamlObject = (TestCustomOpenSamlObject.CustomSamlObject) xmlObject;

for (XMLObject object : customSamlObject.getOrderedChildren()) {
ElementSupport.appendChildElement(domElement, object.getDOM());
}
}

}

public static class CustomSamlObjectUnmarshaller extends AbstractXMLObjectUnmarshaller {

public CustomSamlObjectUnmarshaller() {
super();
}

@Override
protected void processChildElement(@Nonnull XMLObject parentXMLObject, @Nonnull XMLObject childXMLObject)
throws UnmarshallingException {
final TestCustomOpenSamlObject.CustomSamlObject customSamlObject = (TestCustomOpenSamlObject.CustomSamlObject) parentXMLObject;
super.processChildElement(customSamlObject, childXMLObject);
customSamlObject.getUnknownXMLObjects().add(childXMLObject);
}

@Nonnull
@Override
protected XMLObject buildXMLObject(@Nonnull Element domElement) {
return new TestCustomOpenSamlObject.CustomSamlObjectImpl(SAMLConstants.SAML20_NS,
AttributeValue.DEFAULT_ELEMENT_LOCAL_NAME,
TestCustomOpenSamlObject.CustomSamlObject.TYPE_CUSTOM_PREFIX);
}

}

}
Original file line number Diff line number Diff line change
Expand Up @@ -296,6 +296,37 @@ static Attribute attribute(String name, String value) {
return attribute;
}

static List<AttributeStatement> customAttributeStatements() {
List<AttributeStatement> attributeStatements = new ArrayList<>();
AttributeStatementBuilder attributeStatementBuilder = new AttributeStatementBuilder();
AttributeBuilder attributeBuilder = new AttributeBuilder();
Attribute attribute = attributeBuilder.buildObject();
attribute.setName("Address");
TestCustomOpenSamlObject.CustomSamlObject samlObject = new TestCustomOpenSamlObject.CustomSamlObjectBuilder()
.buildObject(AttributeValue.DEFAULT_ELEMENT_NAME, TestCustomOpenSamlObject.CustomSamlObject.TYPE_NAME);
XSAny street = new XSAnyBuilder().buildObject(TestCustomOpenSamlObject.CustomSamlObject.CUSTOM_NS, "Street",
TestCustomOpenSamlObject.CustomSamlObject.TYPE_CUSTOM_PREFIX);
street.setTextContent("Test Street");
samlObject.getUnknownXMLObjects().add(street);
XSAny streetNumber = new XSAnyBuilder().buildObject(TestCustomOpenSamlObject.CustomSamlObject.CUSTOM_NS,
"Number", TestCustomOpenSamlObject.CustomSamlObject.TYPE_CUSTOM_PREFIX);
streetNumber.setTextContent("1");
samlObject.getUnknownXMLObjects().add(streetNumber);
XSAny zip = new XSAnyBuilder().buildObject(TestCustomOpenSamlObject.CustomSamlObject.CUSTOM_NS, "ZIP",
TestCustomOpenSamlObject.CustomSamlObject.TYPE_CUSTOM_PREFIX);
zip.setTextContent("11111");
samlObject.getUnknownXMLObjects().add(zip);
XSAny city = new XSAnyBuilder().buildObject(TestCustomOpenSamlObject.CustomSamlObject.CUSTOM_NS, "City",
TestCustomOpenSamlObject.CustomSamlObject.TYPE_CUSTOM_PREFIX);
city.setTextContent("Test City");
samlObject.getUnknownXMLObjects().add(city);
attribute.getAttributeValues().add(samlObject);
AttributeStatement attributeStatement = attributeStatementBuilder.buildObject();
attributeStatement.getAttributes().add(attribute);
attributeStatements.add(attributeStatement);
return attributeStatements;
}

static List<AttributeStatement> attributeStatements() {
List<AttributeStatement> attributeStatements = new ArrayList<>();
AttributeStatementBuilder attributeStatementBuilder = new AttributeStatementBuilder();
Expand Down