Skip to content

Changes necessary to allow the SD JWT package to build on this. #151

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 12 commits into from
Oct 15, 2023
Merged
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ exclude_lines = [

[tool.poetry]
name = "cryptojwt"
version = "1.8.3"
version = "1.8.4"
description = "Python implementation of JWT, JWE, JWS and JWK"
authors = ["Roland Hedberg <[email protected]>"]
license = "Apache-2.0"
Expand Down
1 change: 0 additions & 1 deletion src/cryptojwt/jwe/jwe_ec.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,7 +213,6 @@ def encrypt(self, key=None, iv="", cek="", **kwargs):
return jwe.pack(parts=[iv, ctxt, tag])

def decrypt(self, token=None, **kwargs):

if isinstance(token, JWEnc):
jwe = token
else:
Expand Down
1 change: 0 additions & 1 deletion src/cryptojwt/jwk/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,6 @@ class JWK(object):
def __init__(
self, kty="", alg="", use="", kid="", x5c=None, x5t="", x5u="", key_ops=None, **kwargs
):

self.extra_args = kwargs

# want kty, alg, use and kid to be strings
Expand Down
2 changes: 1 addition & 1 deletion src/cryptojwt/jwk/ec.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ def ec_construct_public(num):

def ec_construct_private(num):
"""
Given a set of values on public and private attributes build a elliptic
Given a set of values on public and private attributes build an elliptic
curve private key instance.

:param num: A dictionary with public and private attributes and their values
Expand Down
1 change: 0 additions & 1 deletion src/cryptojwt/jwk/okp.py
Original file line number Diff line number Diff line change
Expand Up @@ -321,7 +321,6 @@ def cmp_keys(a, b, key_type):


def new_okp_key(crv, kid="", **kwargs):

_key = OKP_CRV2PRIVATE[crv].generate()

_rk = OKPKey(priv_key=_key, kid=kid, **kwargs)
Expand Down
5 changes: 3 additions & 2 deletions src/cryptojwt/jws/jws.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,8 +118,9 @@ def sign_compact(self, keys=None, protected=None, **kwargs):

key, xargs, _alg = self.alg_keys(keys, "sig", protected)

if "typ" in self:
xargs["typ"] = self["typ"]
for param in ["typ"]:
if param in self:
xargs[param] = self[param]

_headers.update(xargs)
jwt = JWSig(**_headers)
Expand Down
78 changes: 50 additions & 28 deletions src/cryptojwt/jwt.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,10 @@
import time
import uuid
from json import JSONDecodeError
from typing import Dict
from typing import List
from typing import MutableMapping
from typing import Optional

from .exception import HeaderError
from .exception import VerificationError
Expand Down Expand Up @@ -79,25 +83,26 @@ class JWT:
def __init__(
self,
key_jar=None,
iss="",
lifetime=0,
sign=True,
sign_alg="RS256",
encrypt=False,
enc_enc="A128GCM",
enc_alg="RSA-OAEP-256",
msg_cls=None,
iss2msg_cls=None,
skew=15,
allowed_sign_algs=None,
allowed_enc_algs=None,
allowed_enc_encs=None,
allowed_max_lifetime=None,
zip="",
iss: str = "",
lifetime: int = 0,
sign: bool = True,
sign_alg: str = "RS256",
encrypt: bool = False,
enc_enc: str = "A128GCM",
enc_alg: str = "RSA-OAEP-256",
msg_cls: Optional[MutableMapping] = None,
iss2msg_cls: Optional[Dict[str, str]] = None,
skew: Optional[int] = 15,
allowed_sign_algs: Optional[List[str]] = None,
allowed_enc_algs: Optional[List[str]] = None,
allowed_enc_encs: Optional[List[str]] = None,
allowed_max_lifetime: Optional[int] = None,
zip: Optional[str] = "",
typ2msg_cls: Optional[Dict] = None,
):
self.key_jar = key_jar # KeyJar instance
self.iss = iss # My identifier
self.lifetime = lifetime # default life time of the signature
self.lifetime = lifetime # default lifetime of the signature
self.sign = sign # default signing or not
self.alg = sign_alg # default signing algorithm
self.encrypt = encrypt # default encrypting or not
Expand All @@ -107,6 +112,7 @@ def __init__(
self.with_jti = False # If a jti should be added
# A map between issuers and the message classes they use
self.iss2msg_cls = iss2msg_cls or {}
self.typ2msg_cls = typ2msg_cls or {}
# Allowed time skew
self.skew = skew
# When verifying/decrypting
Expand Down Expand Up @@ -206,16 +212,30 @@ def pack_key(self, issuer_id="", kid=""):

return keys[0] # Might be more then one if kid == ''

def pack(self, payload=None, kid="", issuer_id="", recv="", aud=None, iat=None, **kwargs):
def message(self, signing_key, **kwargs):
return json.dumps(kwargs)

def pack(
self,
payload: Optional[dict] = None,
kid: Optional[str] = "",
issuer_id: Optional[str] = "",
recv: Optional[str] = "",
aud: Optional[str] = None,
iat: Optional[int] = None,
jws_headers: Optional[Dict[str, str]] = None,
**kwargs
) -> str:
"""

:param payload: Information to be carried as payload in the JWT
:param kid: Key ID
:param issuer_id: The owner of the the keys that are to be used for signing
:param issuer_id: The owner of the keys that are to be used for signing
:param recv: The intended immediate receiver
:param aud: Intended audience for this JWS/JWE, not expected to
contain the recipient.
:param iat: Override issued at (default current timestamp)
:param jws_headers: JWS headers
:param kwargs: Extra keyword arguments
:return: A signed or signed and encrypted Json Web Token
"""
Expand Down Expand Up @@ -249,10 +269,12 @@ def pack(self, payload=None, kid="", issuer_id="", recv="", aud=None, iat=None,
else:
_key = None

_jws = JWS(json.dumps(_args), alg=self.alg)
_sjwt = _jws.sign_compact([_key])
jws_headers = jws_headers or {}

_jws = JWS(self.message(signing_key=_key, **_args), alg=self.alg)
_sjwt = _jws.sign_compact([_key], protected=jws_headers)
else:
_sjwt = json.dumps(_args)
_sjwt = self.message(signing_key=None, **_args)

if _encrypt:
if not self.sign:
Expand Down Expand Up @@ -300,8 +322,7 @@ def verify_profile(msg_cls, info, **kwargs):
:return: The verified message as a msg_cls instance.
"""
_msg = msg_cls(**info)
if not _msg.verify(**kwargs):
raise VerificationError()
_msg.verify(**kwargs)
return _msg

def unpack(self, token, timestamp=None):
Expand Down Expand Up @@ -373,11 +394,12 @@ def unpack(self, token, timestamp=None):
if self.msg_cls:
_msg_cls = self.msg_cls
else:
try:
# try to find a issuer specific message class
_msg_cls = self.iss2msg_cls[_info["iss"]]
except KeyError:
_msg_cls = None
_msg_cls = None
# try to find an issuer specific message class
if "iss" in _info:
_msg_cls = self.iss2msg_cls.get(_info["iss"])
if not _msg_cls and _jws_header and "typ" in _jws_header:
_msg_cls = self.typ2msg_cls.get(_jws_header["typ"])

timestamp = timestamp or utc_time_sans_frac()

Expand Down
2 changes: 1 addition & 1 deletion src/cryptojwt/jwx.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ class JWx:
:return: A class instance
"""

args = ["alg", "jku", "jwk", "x5u", "x5t", "x5c", "kid", "typ", "cty", "crit"]
args = ["alg", "jku", "jwk", "x5u", "x5t", "x5c", "kid", "typ", "cty", "crit", "trust_chain"]

def __init__(self, msg=None, with_digest=False, httpc=None, **kwargs):
self.msg = msg
Expand Down
3 changes: 1 addition & 2 deletions src/cryptojwt/key_bundle.py
Original file line number Diff line number Diff line change
Expand Up @@ -566,7 +566,6 @@ def update(self):
:return: True if update was ok or False if we encountered an error during update.
"""
if self.source:

try:
if self.local:
if self.fileformat in ["jwks", "jwk"]:
Expand Down Expand Up @@ -681,7 +680,7 @@ def append(self, key):

@keys_writer
def extend(self, keys):
"""Add a key to the list of keys."""
"""Add a list of keys to the list of keys."""
self._keys.extend(keys)

@keys_writer
Expand Down
8 changes: 7 additions & 1 deletion src/cryptojwt/key_jar.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@

from requests import request

from cryptojwt.jwk import JWK

from .exception import IssuerNotFound
from .jwe.jwe import alg2keytype as jwe_alg2keytype
from .jws.utils import alg2keytype as jws_alg2keytype
Expand Down Expand Up @@ -161,6 +163,11 @@ def add_kb(self, issuer_id, kb):
issuer.add_kb(kb)
self._issuers[issuer_id] = issuer

def add_keys(self, issuer_id: str, keys: List[JWK], **kwargs):
_kb = KeyBundle(**kwargs)
_kb.extend(keys)
self.add_kb(issuer_id, _kb)

@deprecated_alias(issuer="issuer_id", owner="issuer_id")
def get(self, key_use, key_type="", issuer_id="", kid=None, **kwargs):
"""
Expand Down Expand Up @@ -475,7 +482,6 @@ def _add_key(
no_kid_issuer=None,
allow_missing_kid=False,
):

_issuer = self._get_issuer(issuer_id)
if _issuer is None:
logger.error('Issuer "{}" not in keyjar'.format(issuer_id))
Expand Down
29 changes: 15 additions & 14 deletions tests/test_09_jwt.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@

import pytest

from cryptojwt.exception import IssuerNotFound
from cryptojwt.jws.exception import NoSuitableSigningKeys
from cryptojwt.jwt import JWT
from cryptojwt.jwt import VerificationError
Expand Down Expand Up @@ -136,19 +135,6 @@ def test_jwt_pack_and_unpack_max_lifetime_exceeded():
_ = bob.unpack(_jwt)


def test_jwt_pack_and_unpack_max_lifetime_exceeded():
lifetime = 3600
alice = JWT(key_jar=ALICE_KEY_JAR, iss=ALICE, sign_alg="RS256", lifetime=lifetime)
payload = {"sub": "sub"}
_jwt = alice.pack(payload=payload)

bob = JWT(
key_jar=BOB_KEY_JAR, iss=BOB, allowed_sign_algs=["RS256"], allowed_max_lifetime=lifetime - 1
)
with pytest.raises(VerificationError):
_ = bob.unpack(_jwt)


def test_jwt_pack_and_unpack_timestamp():
lifetime = 3600
alice = JWT(key_jar=ALICE_KEY_JAR, iss=ALICE, sign_alg="RS256", lifetime=lifetime)
Expand Down Expand Up @@ -258,6 +244,7 @@ class DummyMsg(object):
def __init__(self, **kwargs):
for key, val in kwargs.items():
setattr(self, key, val)
self.jws_headers = {}

def verify(self, **kwargs):
return True
Expand Down Expand Up @@ -331,3 +318,17 @@ def test_eddsa_jwt():
kj.add_kb(ISSUER, KeyBundle(JWKS_DICT))
jwt = JWT(key_jar=kj)
_ = jwt.unpack(JWT_TEST, timestamp=1655278809)


def test_extra_headers():
_kj = KeyJar()
_kj.add_symmetric(ALICE, "hemligt ordsprak", usage=["sig"])

alice = JWT(key_jar=_kj, iss=ALICE, sign_alg="HS256")
payload = {"sub": "sub2"}
_jwt = alice.pack(payload=payload, jws_headers={"xtra": "header", "typ": "dummy"})

bob = JWT(key_jar=_kj, iss=BOB, sign_alg="HS256", typ2msg_cls={"dummy": DummyMsg})
info = bob.unpack(_jwt)
assert isinstance(info, DummyMsg)
assert set(info.jws_header.keys()) == {"xtra", "typ", "alg", "kid"}