diff --git a/src/cryptojwt/jws/pss.py b/src/cryptojwt/jws/pss.py index f8e5a6d..961cc6c 100644 --- a/src/cryptojwt/jws/pss.py +++ b/src/cryptojwt/jws/pss.py @@ -14,10 +14,13 @@ class PSSSigner(Signer): def __init__(self, algorithm="SHA256"): if algorithm == "SHA256": self.hash_algorithm = hashes.SHA256 + self.salt_length = 32 elif algorithm == "SHA384": self.hash_algorithm = hashes.SHA384 + self.salt_length = 48 elif algorithm == "SHA512": self.hash_algorithm = hashes.SHA512 + self.salt_length = 64 else: raise Unsupported(f"algorithm: {algorithm}") @@ -36,7 +39,7 @@ def sign(self, msg, key): digest, padding.PSS( mgf=padding.MGF1(self.hash_algorithm()), - salt_length=padding.PSS.MAX_LENGTH, + salt_length=self.salt_length, ), utils.Prehashed(self.hash_algorithm()), ) @@ -48,7 +51,7 @@ def verify(self, msg, signature, key): :param msg: The message :param sig: A signature - :param key: A ec.EllipticCurvePublicKey to use for the verification. + :param key: A rsa._RSAPublicKey to use for the verification. :raises: BadSignature if the signature can't be verified. :return: True """ @@ -58,7 +61,7 @@ def verify(self, msg, signature, key): msg, padding.PSS( mgf=padding.MGF1(self.hash_algorithm()), - salt_length=padding.PSS.MAX_LENGTH, + salt_length=self.salt_length, ), self.hash_algorithm(), ) diff --git a/src/cryptojwt/key_bundle.py b/src/cryptojwt/key_bundle.py index 172e564..f27f3f8 100755 --- a/src/cryptojwt/key_bundle.py +++ b/src/cryptojwt/key_bundle.py @@ -1134,23 +1134,24 @@ def sort_func(kd1, kd2): def order_key_defs(key_def): """ - Sort a set of key definitions. A key definition that defines more then - one usage type are splitted into as many definitions as the number of + Sort a set of key definitions. A key definition that defines more than + one usage type are split into as many definitions as the number of usage types specified. One key definition per usage type. - :param key_def: A set of key definitions + :param key_def: A set of key definitions. List of dictionaries :return: The set of definitions as a sorted list """ _int = [] # First make sure all defs only reference one usage for _def in key_def: - if len(_def["use"]) > 1: - for _use in _def["use"]: - _kd = _def.copy() - _kd["use"] = _use - _int.append(_kd) - else: - _int.append(_def) + if isinstance(_def, dict): + if len(_def["use"]) > 1: + for _use in _def["use"]: + _kd = _def.copy() + _kd["use"] = _use + _int.append(_kd) + else: + _int.append(_def) _int.sort(key=cmp_to_key(sort_func)) diff --git a/src/cryptojwt/key_jar.py b/src/cryptojwt/key_jar.py index 75c42f5..6c541d3 100755 --- a/src/cryptojwt/key_jar.py +++ b/src/cryptojwt/key_jar.py @@ -492,7 +492,9 @@ def _add_key( if _add_keys[0] not in keys: keys.append(_add_keys[0]) elif allow_missing_kid: - keys.extend(_add_keys) + for _key in _add_keys: + if _key and _key not in keys: + keys.append(_key) elif no_kid_issuer: try: allowed_kids = no_kid_issuer[issuer_id] diff --git a/tests/test_21_pss.py b/tests/test_21_pss.py new file mode 100644 index 0000000..7f99cac --- /dev/null +++ b/tests/test_21_pss.py @@ -0,0 +1,26 @@ +import json + +import pytest + +from cryptojwt.jwk.jwk import key_from_jwk_dict +from cryptojwt.jws.jws import JWS +import test_vector + + +@pytest.mark.parametrize("alg", ["RS256", "RS384", "RS512", "PS256", "PS384", "PS512"]) +def test_jws_rsa_signer_and_verifier(alg): + _jwk_dict = json.loads(test_vector.json_rsa_priv_key) + _key = key_from_jwk_dict(_jwk_dict) + _key.alg = alg + _key.add_kid() + + json_header_rsa = json.loads(test_vector.test_header_rsa) + json_header_rsa["alg"] = alg + + # Sign + jws = JWS(msg=test_vector.test_payload, **json_header_rsa) + signed_token = jws.sign_compact([_key]) + + # Verify + verifier = JWS(alg=[alg]) + assert verifier.verify_compact(signed_token, [_key])