diff --git a/pyproject.toml b/pyproject.toml index 985692043..8a7cd9185 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -37,12 +37,11 @@ parse_xsd2 = "saml2.tools.parse_xsd2:main" [tool.poetry.dependencies] python = "^3.9" -cryptography = ">=3.1" +cryptography = ">=40.0" defusedxml = "*" importlib-metadata = {version = ">=1.7.0", python = "<3.8"} importlib-resources = {python = "<3.9", version = "*"} paste = {optional = true, version = "*"} -pyopenssl = "<24.3.0" python-dateutil = "*" pytz = "*" "repoze.who" = {optional = true, version = "*"} diff --git a/src/saml2/cert.py b/src/saml2/cert.py index c5f626601..1759b9b24 100644 --- a/src/saml2/cert.py +++ b/src/saml2/cert.py @@ -5,7 +5,11 @@ from os import remove from os.path import join -from OpenSSL import crypto +from cryptography import x509 +from cryptography.exceptions import InvalidSignature +from cryptography.hazmat.primitives import hashes, serialization +from cryptography.hazmat.primitives.asymmetric import rsa +from cryptography.x509.oid import NameOID import dateutil.parser import pytz @@ -36,7 +40,6 @@ def create_certificate( valid_to=315360000, sn=1, key_length=1024, - hash_alg="sha256", write_to_file=False, cert_dir="", cipher_passphrase=None, @@ -87,8 +90,6 @@ def create_certificate( is 1. :param key_length: Length of the key to be generated. Defaults to 1024. - :param hash_alg: Hash algorithm to use for the key. Default - is sha256. :param write_to_file: True if you want to write the certificate to a file. The method will then return a tuple with path to certificate file and @@ -131,49 +132,68 @@ def create_certificate( k_f = join(cert_dir, key_file) # create a key pair - k = crypto.PKey() - k.generate_key(crypto.TYPE_RSA, key_length) + k = rsa.generate_private_key( + public_exponent=65537, + key_size=key_length, + ) # create a self-signed cert - cert = crypto.X509() + builder = x509.CertificateBuilder() if request: - cert = crypto.X509Req() + builder = x509.CertificateSigningRequestBuilder() if len(cert_info["country_code"]) != 2: raise WrongInput("Country code must be two letters!") - cert.get_subject().C = cert_info["country_code"] - cert.get_subject().ST = cert_info["state"] - cert.get_subject().L = cert_info["city"] - cert.get_subject().O = cert_info["organization"] # noqa: E741 - cert.get_subject().OU = cert_info["organization_unit"] - cert.get_subject().CN = cn + subject_name = x509.Name([ + x509.NameAttribute(NameOID.COUNTRY_NAME, + cert_info["country_code"]), + x509.NameAttribute(NameOID.STATE_OR_PROVINCE_NAME, + cert_info["state"]), + x509.NameAttribute(NameOID.LOCALITY_NAME, + cert_info["city"]), + x509.NameAttribute(NameOID.ORGANIZATION_NAME, + cert_info["organization"]), + x509.NameAttribute(NameOID.ORGANIZATIONAL_UNIT_NAME, + cert_info["organization_unit"]), + x509.NameAttribute(NameOID.COMMON_NAME, cn), + ]) + builder = builder.subject_name(subject_name) if not request: - cert.set_serial_number(sn) - cert.gmtime_adj_notBefore(valid_from) # Valid before present time - cert.gmtime_adj_notAfter(valid_to) # 3 650 days - cert.set_issuer(cert.get_subject()) - cert.set_pubkey(k) - cert.sign(k, hash_alg) + now = datetime.datetime.now(datetime.UTC) + builder = builder.serial_number( + sn, + ).not_valid_before( + now + datetime.timedelta(seconds=valid_from), + ).not_valid_after( + now + datetime.timedelta(seconds=valid_to), + ).issuer_name( + subject_name, + ).public_key( + k.public_key(), + ) + cert = builder.sign(k, hashes.SHA256()) try: - if request: - tmp_cert = crypto.dump_certificate_request(crypto.FILETYPE_PEM, cert) - else: - tmp_cert = crypto.dump_certificate(crypto.FILETYPE_PEM, cert) - tmp_key = None + tmp_cert = cert.public_bytes(serialization.Encoding.PEM) + key_encryption = None if cipher_passphrase is not None: passphrase = cipher_passphrase["passphrase"] if isinstance(cipher_passphrase["passphrase"], str): passphrase = passphrase.encode("utf-8") - tmp_key = crypto.dump_privatekey(crypto.FILETYPE_PEM, k, cipher_passphrase["cipher"], passphrase) + key_encryption = serialization.BestAvailableEncryption(passphrase) else: - tmp_key = crypto.dump_privatekey(crypto.FILETYPE_PEM, k) + key_encryption = serialization.NoEncryption() + tmp_key = k.private_bytes( + encoding=serialization.Encoding.PEM, + format=serialization.PrivateFormat.TraditionalOpenSSL, + encryption_algorithm=key_encryption, + ) if write_to_file: - with open(c_f, "w") as fc: - fc.write(tmp_cert.decode("utf-8")) - with open(k_f, "w") as fk: - fk.write(tmp_key.decode("utf-8")) + with open(c_f, "wb") as fc: + fc.write(tmp_cert) + with open(k_f, "wb") as fk: + fk.write(tmp_key) return c_f, k_f return tmp_cert, tmp_key except Exception as ex: @@ -198,7 +218,6 @@ def create_cert_signed_certificate( sign_cert_str, sign_key_str, request_cert_str, - hash_alg="sha256", valid_from=0, valid_to=315360000, sn=1, @@ -222,8 +241,6 @@ def create_cert_signed_certificate( the requested certificate. If you only have a file use the method read_str_from_file to get a string representation. - :param hash_alg: Hash algorithm to use for the key. Default - is sha256. :param valid_from: When the certificate starts to be valid. Amount of seconds from when the certificate is generated. @@ -237,27 +254,29 @@ def create_cert_signed_certificate( :return: String representation of the signed certificate. """ - ca_cert = crypto.load_certificate(crypto.FILETYPE_PEM, sign_cert_str) - ca_key = None - if passphrase is not None: - ca_key = crypto.load_privatekey(crypto.FILETYPE_PEM, sign_key_str, passphrase) - else: - ca_key = crypto.load_privatekey(crypto.FILETYPE_PEM, sign_key_str) - req_cert = crypto.load_certificate_request(crypto.FILETYPE_PEM, request_cert_str) - - cert = crypto.X509() - cert.set_subject(req_cert.get_subject()) - cert.set_serial_number(sn) - cert.gmtime_adj_notBefore(valid_from) - cert.gmtime_adj_notAfter(valid_to) - cert.set_issuer(ca_cert.get_subject()) - cert.set_pubkey(req_cert.get_pubkey()) - cert.sign(ca_key, hash_alg) - - cert_dump = crypto.dump_certificate(crypto.FILETYPE_PEM, cert) - if isinstance(cert_dump, str): - return cert_dump - return cert_dump.decode("utf-8") + if isinstance(sign_cert_str, str): + sign_cert_str = sign_cert_str.encode("utf-8") + ca_cert = x509.load_pem_x509_certificate(sign_cert_str) + ca_key = serialization.load_pem_private_key( + sign_key_str, password=passphrase) + req_cert = x509.load_pem_x509_csr(request_cert_str) + + now = datetime.datetime.now(datetime.UTC) + cert = x509.CertificateBuilder().subject_name( + req_cert.subject, + ).serial_number( + sn, + ).not_valid_before( + now + datetime.timedelta(seconds=valid_from), + ).not_valid_after( + now + datetime.timedelta(seconds=valid_to), + ).issuer_name( + ca_cert.subject, + ).public_key( + req_cert.public_key(), + ).sign(ca_key, hashes.SHA256()) + + return cert.public_bytes(serialization.Encoding.PEM).decode("utf-8") def verify_chain(self, cert_chain_str_list, cert_str): """ @@ -276,13 +295,6 @@ def verify_chain(self, cert_chain_str_list, cert_str): cert_str = tmp_cert_str return (True, "Signed certificate is valid and correctly signed by CA " "certificate.") - def certificate_not_valid_yet(self, cert): - starts_to_be_valid = dateutil.parser.parse(cert.get_notBefore()) - now = pytz.UTC.localize(datetime.datetime.utcnow()) - if starts_to_be_valid < now: - return False - return True - def verify(self, signing_cert_str, cert_str): """ Verifies if a certificate is valid and signed by a given certificate. @@ -303,34 +315,34 @@ def verify(self, signing_cert_str, cert_str): Message = Why the validation failed. """ try: - ca_cert = crypto.load_certificate(crypto.FILETYPE_PEM, signing_cert_str) - cert = crypto.load_certificate(crypto.FILETYPE_PEM, cert_str) - - if self.certificate_not_valid_yet(ca_cert): + if isinstance(signing_cert_str, str): + signing_cert_str = signing_cert_str.encode("utf-8") + if isinstance(cert_str, str): + cert_str = cert_str.encode("utf-8") + ca_cert = x509.load_pem_x509_certificate(signing_cert_str) + cert = x509.load_pem_x509_certificate(cert_str) + now = datetime.datetime.now(datetime.UTC) + + if ca_cert.not_valid_before_utc >= now: return False, "CA certificate is not valid yet." - if ca_cert.has_expired() == 1: + if ca_cert.not_valid_after_utc < now: return False, "CA certificate is expired." - if cert.has_expired() == 1: + if cert.not_valid_after_utc < now: return False, "The signed certificate is expired." - if self.certificate_not_valid_yet(cert): + if cert.not_valid_before_utc >= now: return False, "The signed certificate is not valid yet." - if ca_cert.get_subject().CN == cert.get_subject().CN: + if ca_cert.subject.get_attributes_for_oid(NameOID.COMMON_NAME) == \ + cert.subject.get_attributes_for_oid(NameOID.COMMON_NAME): return False, ("CN may not be equal for CA certificate and the " "signed certificate.") - cert_algorithm = cert.get_signature_algorithm() - cert_algorithm = cert_algorithm.decode("ascii") - cert_str = cert_str.encode("ascii") - - cert_crypto = saml2.cryptography.pki.load_pem_x509_certificate(cert_str) - try: - crypto.verify(ca_cert, cert_crypto.signature, cert_crypto.tbs_certificate_bytes, cert_algorithm) + cert.verify_directly_issued_by(ca_cert) return True, "Signed certificate is valid and correctly signed by CA certificate." - except crypto.Error as e: + except (ValueError, TypeError, InvalidSignature) as e: return False, f"Certificate is incorrectly signed: {str(e)}" except Exception as e: return False, f"Certificate is not valid for an unknown reason. {str(e)}" @@ -352,8 +364,14 @@ def read_cert_from_file(cert_file, cert_type="pem"): data = fp.read() try: - cert = saml2.cryptography.pki.load_x509_certificate(data, cert_type) - pem_data = saml2.cryptography.pki.get_public_bytes_from_cert(cert) + cert = None + if cert_type == "pem": + cert = x509.load_pem_x509_certificate(data) + elif cert_type == "der": + cert = x509.load_der_x509_certificate(data) + else: + raise ValueError(f"cert-type {cert_type} not supported") + pem_data = cert.public_bytes(serialization.Encoding.PEM).decode("utf-8") except Exception as e: raise CertificateError(e) diff --git a/src/saml2/sigver.py b/src/saml2/sigver.py index f3af1ec99..98d11b1d1 100644 --- a/src/saml2/sigver.py +++ b/src/saml2/sigver.py @@ -28,7 +28,7 @@ from urllib import parse -from OpenSSL import crypto +from cryptography import x509 import pytz from saml2 import ExtensionElement @@ -383,14 +383,14 @@ def active_cert(key): """ try: cert_str = pem_format(key) - cert = crypto.load_certificate(crypto.FILETYPE_PEM, cert_str) + cert = x509.load_pem_x509_certificate(cert_str) except AttributeError: return False - now = pytz.UTC.localize(datetime.datetime.utcnow()) - valid_from = dateutil.parser.parse(cert.get_notBefore()) - valid_to = dateutil.parser.parse(cert.get_notAfter()) - active = not cert.has_expired() and valid_from <= now < valid_to + now = datetime.datetime.now(datetime.UTC) + valid_from = cert.not_valid_before_utc + valid_to = cert.not_valid_after_utc + active = valid_from <= now < valid_to return active