Skip to content

Tests for finding out if a token is a compact JWS, json JWS or a JWE.… #140

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 8 commits into from
May 13, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/cryptojwt/jws/utils.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
# import struct

from cryptography.hazmat.primitives import hashes
from cryptography.hazmat.primitives.asymmetric import padding

Expand Down
98 changes: 95 additions & 3 deletions src/cryptojwt/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@

DEFAULT_HTTPC_TIMEOUT = 10


# ---------------------------------------------------------------------------
# Helper functions

Expand Down Expand Up @@ -193,7 +194,7 @@ def split_token(token):

def deser(val):
"""
Deserialize from a string representation of an long integer
Deserialize from a string representation of a long integer
to the python representation of a long integer.

:param val: The string representation of the long integer.
Expand All @@ -212,12 +213,12 @@ def modsplit(name):
if ":" in name:
_part = name.split(":")
if len(_part) != 2:
raise ValueError(f"Syntax error: {s}")
raise ValueError(f"Syntax error: {name}")
return _part[0], _part[1]

_part = name.split(".")
if len(_part) < 2:
raise ValueError(f"Syntax error: {s}")
raise ValueError(f"Syntax error: {name}")

return ".".join(_part[:-1]), _part[-1]

Expand Down Expand Up @@ -273,3 +274,94 @@ def check_content_type(content_type, mime_type):
msg["content-type"] = content_type
mt = msg.get_content_type()
return mime_type == mt


def is_compact_jws(token):
token = as_bytes(token)

try:
part = split_token(token)
except BadSyntax:
return False

# Should be three parts
if len(part) != 3:
return False

# All base64 encoded
try:
part = [b64d(p) for p in part]
except Exception:
return False

# header should be a JSON object, 'alg' most be one parameter
try:
_header = json.loads(part[0])
except Exception:
return False

if "alg" not in _header:
return False

return True


def is_jwe(token):
token = as_bytes(token)

try:
part = split_token(token)
except BadSyntax:
return False

# Should be five parts
if len(part) != 5:
return False

# All base64 encoded
try:
part = [b64d(p) for p in part]
except Exception:
return False

# header should be a JSON object, 'alg' most be one parameter
try:
_header = json.loads(part[0])
except Exception:
return False

if "alg" not in _header or "enc" not in _header:
return False

return True


def is_json_jws(token):
if isinstance(token, str):
try:
token = json.loads(token)
except Exception:
return False

for arg in ["payload", "signatures"]:
if arg not in token:
return False

if not isinstance(token["signatures"], list):
return False

for sign in token["signatures"]:
if not isinstance(sign, dict):
return False
if "signature" not in sign:
return False

return True


def is_jws(token):
if is_json_jws(token):
return "json"
elif is_compact_jws(token):
return "compact"
return False
52 changes: 51 additions & 1 deletion tests/test_06_jws.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives.asymmetric import ec

from cryptojwt import as_unicode
from cryptojwt.exception import BadSignature
from cryptojwt.exception import UnknownAlgorithm
from cryptojwt.exception import WrongNumberOfParts
Expand All @@ -25,10 +26,13 @@
from cryptojwt.jws.utils import left_hash
from cryptojwt.jws.utils import parse_rsa_algorithm
from cryptojwt.key_bundle import KeyBundle
from cryptojwt.utils import as_bytes
from cryptojwt.utils import b64d
from cryptojwt.utils import b64d_enc_dec
from cryptojwt.utils import b64e
from cryptojwt.utils import intarr2bin
from cryptojwt.utils import is_compact_jws
from cryptojwt.utils import is_json_jws

BASEDIR = os.path.abspath(os.path.dirname(__file__))

Expand Down Expand Up @@ -297,7 +301,6 @@ def full_path(local_file):
]
}


SIGJWKS = KeyBundle(JWKS_b)


Expand Down Expand Up @@ -1020,3 +1023,50 @@ def test_verify_json_missing_key():

# With both
assert JWS().verify_json(_jwt, keys=[vkeys[0], sym_key])


def test_is_compact_jws():
_header = {"foo": "bar", "alg": "HS384"}
_payload = "hello world"
_sym_key = SYMKey(key=b"My hollow echo chamber", alg="HS384")

_jwt = JWS(msg=_payload, alg="HS384").sign_compact(keys=[_sym_key])

assert is_compact_jws(_jwt)

# Faulty examples

# to few parts
assert is_compact_jws("abc.def") is False

# right number of parts but not base64

assert is_compact_jws("abc.def.ghi") is False

# not base64 illegal characters
assert is_compact_jws("abc.::::.ghi") is False

# Faulty header
_faulty_header = {"foo": "bar"} # alg is a MUST
_jwt = ".".join([as_unicode(b64e(as_bytes(json.dumps(_faulty_header)))), "def", "ghi"])
assert is_compact_jws(_jwt) is False


def test_is_json_jws():
ec_key = ECKey().load_key(P256())
sym_key = SYMKey(key=b"My hollow echo chamber", alg="HS384")

protected_headers_1 = {"foo": "bar", "alg": "ES256"}
unprotected_headers_1 = {"abc": "xyz"}
protected_headers_2 = {"foo": "bar", "alg": "HS384"}
unprotected_headers_2 = {"abc": "zeb"}
payload = "hello world"
_jwt = JWS(msg=payload).sign_json(
headers=[
(protected_headers_1, unprotected_headers_1),
(protected_headers_2, unprotected_headers_2),
],
keys=[ec_key, sym_key],
)

assert is_json_jws(_jwt)
9 changes: 9 additions & 0 deletions tests/test_07_jwe.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,8 @@

__author__ = "rohe0002"

from cryptojwt.utils import is_jwe


def rndstr(size=16):
"""
Expand Down Expand Up @@ -717,3 +719,10 @@ def test_fernet_blake2s():
decrypter = encrypter
resp = decrypter.decrypt(_token)
assert resp == plain


def test_is_jwe():
encryption_key = SYMKey(use="enc", key="DukeofHazardpass", kid="some-key-id")
jwe = JWE(plain, alg="A128KW", enc="A128CBC-HS256")
_jwe = jwe.encrypt(keys=[encryption_key], kid="some-key-id")
assert is_jwe(_jwe)