diff --git a/pyproject.toml b/pyproject.toml index 9536b461..247d3959 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -22,7 +22,7 @@ exclude_lines = [ [tool.poetry] name = "cryptojwt" -version = "1.8.3" +version = "1.8.4" description = "Python implementation of JWT, JWE, JWS and JWK" authors = ["Roland Hedberg "] license = "Apache-2.0" diff --git a/src/cryptojwt/jwe/jwe_ec.py b/src/cryptojwt/jwe/jwe_ec.py index 3321a8e1..b8c9bc6b 100644 --- a/src/cryptojwt/jwe/jwe_ec.py +++ b/src/cryptojwt/jwe/jwe_ec.py @@ -213,7 +213,6 @@ def encrypt(self, key=None, iv="", cek="", **kwargs): return jwe.pack(parts=[iv, ctxt, tag]) def decrypt(self, token=None, **kwargs): - if isinstance(token, JWEnc): jwe = token else: diff --git a/src/cryptojwt/jwk/__init__.py b/src/cryptojwt/jwk/__init__.py index d8fd480a..fe4e9cc9 100644 --- a/src/cryptojwt/jwk/__init__.py +++ b/src/cryptojwt/jwk/__init__.py @@ -31,7 +31,6 @@ class JWK(object): def __init__( self, kty="", alg="", use="", kid="", x5c=None, x5t="", x5u="", key_ops=None, **kwargs ): - self.extra_args = kwargs # want kty, alg, use and kid to be strings diff --git a/src/cryptojwt/jwk/ec.py b/src/cryptojwt/jwk/ec.py index f1bc61ea..4acf9d9c 100644 --- a/src/cryptojwt/jwk/ec.py +++ b/src/cryptojwt/jwk/ec.py @@ -55,7 +55,7 @@ def ec_construct_public(num): def ec_construct_private(num): """ - Given a set of values on public and private attributes build a elliptic + Given a set of values on public and private attributes build an elliptic curve private key instance. :param num: A dictionary with public and private attributes and their values diff --git a/src/cryptojwt/jwk/okp.py b/src/cryptojwt/jwk/okp.py index 7d165f74..83159629 100644 --- a/src/cryptojwt/jwk/okp.py +++ b/src/cryptojwt/jwk/okp.py @@ -321,7 +321,6 @@ def cmp_keys(a, b, key_type): def new_okp_key(crv, kid="", **kwargs): - _key = OKP_CRV2PRIVATE[crv].generate() _rk = OKPKey(priv_key=_key, kid=kid, **kwargs) diff --git a/src/cryptojwt/jws/jws.py b/src/cryptojwt/jws/jws.py index 7da26fa3..f521cbc9 100644 --- a/src/cryptojwt/jws/jws.py +++ b/src/cryptojwt/jws/jws.py @@ -118,8 +118,9 @@ def sign_compact(self, keys=None, protected=None, **kwargs): key, xargs, _alg = self.alg_keys(keys, "sig", protected) - if "typ" in self: - xargs["typ"] = self["typ"] + for param in ["typ"]: + if param in self: + xargs[param] = self[param] _headers.update(xargs) jwt = JWSig(**_headers) diff --git a/src/cryptojwt/jwt.py b/src/cryptojwt/jwt.py index e212c772..01469340 100755 --- a/src/cryptojwt/jwt.py +++ b/src/cryptojwt/jwt.py @@ -4,6 +4,10 @@ import time import uuid from json import JSONDecodeError +from typing import Dict +from typing import List +from typing import MutableMapping +from typing import Optional from .exception import HeaderError from .exception import VerificationError @@ -79,25 +83,26 @@ class JWT: def __init__( self, key_jar=None, - iss="", - lifetime=0, - sign=True, - sign_alg="RS256", - encrypt=False, - enc_enc="A128GCM", - enc_alg="RSA-OAEP-256", - msg_cls=None, - iss2msg_cls=None, - skew=15, - allowed_sign_algs=None, - allowed_enc_algs=None, - allowed_enc_encs=None, - allowed_max_lifetime=None, - zip="", + iss: str = "", + lifetime: int = 0, + sign: bool = True, + sign_alg: str = "RS256", + encrypt: bool = False, + enc_enc: str = "A128GCM", + enc_alg: str = "RSA-OAEP-256", + msg_cls: Optional[MutableMapping] = None, + iss2msg_cls: Optional[Dict[str, str]] = None, + skew: Optional[int] = 15, + allowed_sign_algs: Optional[List[str]] = None, + allowed_enc_algs: Optional[List[str]] = None, + allowed_enc_encs: Optional[List[str]] = None, + allowed_max_lifetime: Optional[int] = None, + zip: Optional[str] = "", + typ2msg_cls: Optional[Dict] = None, ): self.key_jar = key_jar # KeyJar instance self.iss = iss # My identifier - self.lifetime = lifetime # default life time of the signature + self.lifetime = lifetime # default lifetime of the signature self.sign = sign # default signing or not self.alg = sign_alg # default signing algorithm self.encrypt = encrypt # default encrypting or not @@ -107,6 +112,7 @@ def __init__( self.with_jti = False # If a jti should be added # A map between issuers and the message classes they use self.iss2msg_cls = iss2msg_cls or {} + self.typ2msg_cls = typ2msg_cls or {} # Allowed time skew self.skew = skew # When verifying/decrypting @@ -206,16 +212,30 @@ def pack_key(self, issuer_id="", kid=""): return keys[0] # Might be more then one if kid == '' - def pack(self, payload=None, kid="", issuer_id="", recv="", aud=None, iat=None, **kwargs): + def message(self, signing_key, **kwargs): + return json.dumps(kwargs) + + def pack( + self, + payload: Optional[dict] = None, + kid: Optional[str] = "", + issuer_id: Optional[str] = "", + recv: Optional[str] = "", + aud: Optional[str] = None, + iat: Optional[int] = None, + jws_headers: Optional[Dict[str, str]] = None, + **kwargs + ) -> str: """ :param payload: Information to be carried as payload in the JWT :param kid: Key ID - :param issuer_id: The owner of the the keys that are to be used for signing + :param issuer_id: The owner of the keys that are to be used for signing :param recv: The intended immediate receiver :param aud: Intended audience for this JWS/JWE, not expected to contain the recipient. :param iat: Override issued at (default current timestamp) + :param jws_headers: JWS headers :param kwargs: Extra keyword arguments :return: A signed or signed and encrypted Json Web Token """ @@ -249,10 +269,12 @@ def pack(self, payload=None, kid="", issuer_id="", recv="", aud=None, iat=None, else: _key = None - _jws = JWS(json.dumps(_args), alg=self.alg) - _sjwt = _jws.sign_compact([_key]) + jws_headers = jws_headers or {} + + _jws = JWS(self.message(signing_key=_key, **_args), alg=self.alg) + _sjwt = _jws.sign_compact([_key], protected=jws_headers) else: - _sjwt = json.dumps(_args) + _sjwt = self.message(signing_key=None, **_args) if _encrypt: if not self.sign: @@ -300,8 +322,7 @@ def verify_profile(msg_cls, info, **kwargs): :return: The verified message as a msg_cls instance. """ _msg = msg_cls(**info) - if not _msg.verify(**kwargs): - raise VerificationError() + _msg.verify(**kwargs) return _msg def unpack(self, token, timestamp=None): @@ -373,11 +394,12 @@ def unpack(self, token, timestamp=None): if self.msg_cls: _msg_cls = self.msg_cls else: - try: - # try to find a issuer specific message class - _msg_cls = self.iss2msg_cls[_info["iss"]] - except KeyError: - _msg_cls = None + _msg_cls = None + # try to find an issuer specific message class + if "iss" in _info: + _msg_cls = self.iss2msg_cls.get(_info["iss"]) + if not _msg_cls and _jws_header and "typ" in _jws_header: + _msg_cls = self.typ2msg_cls.get(_jws_header["typ"]) timestamp = timestamp or utc_time_sans_frac() diff --git a/src/cryptojwt/jwx.py b/src/cryptojwt/jwx.py index b941a497..52c696b4 100644 --- a/src/cryptojwt/jwx.py +++ b/src/cryptojwt/jwx.py @@ -50,7 +50,7 @@ class JWx: :return: A class instance """ - args = ["alg", "jku", "jwk", "x5u", "x5t", "x5c", "kid", "typ", "cty", "crit"] + args = ["alg", "jku", "jwk", "x5u", "x5t", "x5c", "kid", "typ", "cty", "crit", "trust_chain"] def __init__(self, msg=None, with_digest=False, httpc=None, **kwargs): self.msg = msg diff --git a/src/cryptojwt/key_bundle.py b/src/cryptojwt/key_bundle.py index 451503cb..0fde736c 100755 --- a/src/cryptojwt/key_bundle.py +++ b/src/cryptojwt/key_bundle.py @@ -566,7 +566,6 @@ def update(self): :return: True if update was ok or False if we encountered an error during update. """ if self.source: - try: if self.local: if self.fileformat in ["jwks", "jwk"]: @@ -681,7 +680,7 @@ def append(self, key): @keys_writer def extend(self, keys): - """Add a key to the list of keys.""" + """Add a list of keys to the list of keys.""" self._keys.extend(keys) @keys_writer diff --git a/src/cryptojwt/key_jar.py b/src/cryptojwt/key_jar.py index c9ab1bb9..f3716b07 100755 --- a/src/cryptojwt/key_jar.py +++ b/src/cryptojwt/key_jar.py @@ -5,6 +5,8 @@ from requests import request +from cryptojwt.jwk import JWK + from .exception import IssuerNotFound from .jwe.jwe import alg2keytype as jwe_alg2keytype from .jws.utils import alg2keytype as jws_alg2keytype @@ -161,6 +163,11 @@ def add_kb(self, issuer_id, kb): issuer.add_kb(kb) self._issuers[issuer_id] = issuer + def add_keys(self, issuer_id: str, keys: List[JWK], **kwargs): + _kb = KeyBundle(**kwargs) + _kb.extend(keys) + self.add_kb(issuer_id, _kb) + @deprecated_alias(issuer="issuer_id", owner="issuer_id") def get(self, key_use, key_type="", issuer_id="", kid=None, **kwargs): """ @@ -475,7 +482,6 @@ def _add_key( no_kid_issuer=None, allow_missing_kid=False, ): - _issuer = self._get_issuer(issuer_id) if _issuer is None: logger.error('Issuer "{}" not in keyjar'.format(issuer_id)) diff --git a/tests/test_09_jwt.py b/tests/test_09_jwt.py index 0bb912fd..c8ec65ca 100755 --- a/tests/test_09_jwt.py +++ b/tests/test_09_jwt.py @@ -2,7 +2,6 @@ import pytest -from cryptojwt.exception import IssuerNotFound from cryptojwt.jws.exception import NoSuitableSigningKeys from cryptojwt.jwt import JWT from cryptojwt.jwt import VerificationError @@ -136,19 +135,6 @@ def test_jwt_pack_and_unpack_max_lifetime_exceeded(): _ = bob.unpack(_jwt) -def test_jwt_pack_and_unpack_max_lifetime_exceeded(): - lifetime = 3600 - alice = JWT(key_jar=ALICE_KEY_JAR, iss=ALICE, sign_alg="RS256", lifetime=lifetime) - payload = {"sub": "sub"} - _jwt = alice.pack(payload=payload) - - bob = JWT( - key_jar=BOB_KEY_JAR, iss=BOB, allowed_sign_algs=["RS256"], allowed_max_lifetime=lifetime - 1 - ) - with pytest.raises(VerificationError): - _ = bob.unpack(_jwt) - - def test_jwt_pack_and_unpack_timestamp(): lifetime = 3600 alice = JWT(key_jar=ALICE_KEY_JAR, iss=ALICE, sign_alg="RS256", lifetime=lifetime) @@ -258,6 +244,7 @@ class DummyMsg(object): def __init__(self, **kwargs): for key, val in kwargs.items(): setattr(self, key, val) + self.jws_headers = {} def verify(self, **kwargs): return True @@ -331,3 +318,17 @@ def test_eddsa_jwt(): kj.add_kb(ISSUER, KeyBundle(JWKS_DICT)) jwt = JWT(key_jar=kj) _ = jwt.unpack(JWT_TEST, timestamp=1655278809) + + +def test_extra_headers(): + _kj = KeyJar() + _kj.add_symmetric(ALICE, "hemligt ordsprak", usage=["sig"]) + + alice = JWT(key_jar=_kj, iss=ALICE, sign_alg="HS256") + payload = {"sub": "sub2"} + _jwt = alice.pack(payload=payload, jws_headers={"xtra": "header", "typ": "dummy"}) + + bob = JWT(key_jar=_kj, iss=BOB, sign_alg="HS256", typ2msg_cls={"dummy": DummyMsg}) + info = bob.unpack(_jwt) + assert isinstance(info, DummyMsg) + assert set(info.jws_header.keys()) == {"xtra", "typ", "alg", "kid"}