Skip to content

Prepare release 1.3.0 #68

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 18 commits into from
Sep 11, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/cryptojwt/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
except ImportError:
pass

__version__ = "1.2.0"
__version__ = "1.3.0"

logger = logging.getLogger(__name__)

Expand Down
4 changes: 0 additions & 4 deletions src/cryptojwt/exception.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,10 +63,6 @@ class UpdateFailed(KeyIOError):
pass


class UnknownKeytype(Invalid):
"""An unknown key type"""


class JWKException(JWKESTException):
pass

Expand Down
73 changes: 50 additions & 23 deletions src/cryptojwt/key_bundle.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import logging
import os
import time
from datetime import datetime
from functools import cmp_to_key

import requests
Expand Down Expand Up @@ -156,6 +157,7 @@ def __init__(
keys=None,
source="",
cache_time=300,
ignore_errors_period=0,
fileformat="jwks",
keytype="RSA",
keyusage=None,
Expand Down Expand Up @@ -188,6 +190,8 @@ def __init__(
self.remote = False
self.local = False
self.cache_time = cache_time
self.ignore_errors_period = ignore_errors_period
self.ignore_errors_until = None # UNIX timestamp of last error
self.time_out = 0
self.etag = ""
self.source = None
Expand Down Expand Up @@ -314,7 +318,11 @@ def do_local_jwk(self, filename):
Load a JWKS from a local file

:param filename: Name of the file from which the JWKS should be loaded
:return: True if load was successful or False if file hasn't been modified
"""
if not self._local_update_required():
return False

LOGGER.info("Reading local JWKS from %s", filename)
with open(filename) as input_file:
_info = json.load(input_file)
Expand All @@ -324,6 +332,7 @@ def do_local_jwk(self, filename):
self.do_keys([_info])
self.last_local = time.time()
self.time_out = self.last_local + self.cache_time
return True

def do_local_der(self, filename, keytype, keyusage=None, kid=""):
"""
Expand All @@ -332,7 +341,11 @@ def do_local_der(self, filename, keytype, keyusage=None, kid=""):
:param filename: Name of the file
:param keytype: Presently 'rsa' and 'ec' supported
:param keyusage: encryption ('enc') or signing ('sig') or both
:return: True if load was successful or False if file hasn't been modified
"""
if not self._local_update_required():
return False

LOGGER.info("Reading local DER from %s", filename)
key_args = {}
_kty = keytype.lower()
Expand All @@ -355,16 +368,25 @@ def do_local_der(self, filename, keytype, keyusage=None, kid=""):
self.do_keys([key_args])
self.last_local = time.time()
self.time_out = self.last_local + self.cache_time
return True

def do_remote(self):
"""
Load a JWKS from a webpage.

:return: True or False if load was successful
:return: True if load was successful or False if remote hasn't been modified
"""
# if self.verify_ssl is not None:
# self.httpc_params["verify"] = self.verify_ssl

if self.ignore_errors_until and time.time() < self.ignore_errors_until:
LOGGER.warning(
"Not reading remote JWKS from %s (in error holddown until %s)",
self.source,
datetime.fromtimestamp(self.ignore_errors_until),
)
return False

LOGGER.info("Reading remote JWKS from %s", self.source)
try:
LOGGER.debug("KeyBundle fetch keys from: %s", self.source)
Expand All @@ -378,7 +400,10 @@ def do_remote(self):
LOGGER.error(err)
raise UpdateFailed(REMOTE_FAILED.format(self.source, str(err)))

if _http_resp.status_code == 200: # New content
load_successful = _http_resp.status_code == 200
not_modified = _http_resp.status_code == 304

if load_successful:
self.time_out = time.time() + self.cache_time

self.imp_jwks = self._parse_remote_response(_http_resp)
Expand All @@ -390,25 +415,27 @@ def do_remote(self):
self.do_keys(self.imp_jwks["keys"])
except KeyError:
LOGGER.error("No 'keys' keyword in JWKS")
self.ignore_errors_until = time.time() + self.ignore_errors_period
raise UpdateFailed(MALFORMED.format(self.source))

if hasattr(_http_resp, "headers"):
headers = getattr(_http_resp, "headers")
self.last_remote = headers.get("last-modified") or headers.get("date")

elif _http_resp.status_code == 304: # Not modified
elif not_modified:
LOGGER.debug("%s not modified since %s", self.source, self.last_remote)
self.time_out = time.time() + self.cache_time

else:
LOGGER.warning(
"HTTP status %d reading remote JWKS from %s",
_http_resp.status_code,
self.source,
)
self.ignore_errors_until = time.time() + self.ignore_errors_period
raise UpdateFailed(REMOTE_FAILED.format(self.source, _http_resp.status_code))

self.last_updated = time.time()
return True
self.ignore_errors_until = None
return load_successful

def _parse_remote_response(self, response):
"""
Expand All @@ -433,23 +460,20 @@ def _parse_remote_response(self, response):
return None

def _uptodate(self):
res = False
if self.remote or self.local:
if time.time() > self.time_out:
if self.local and not self._local_update_required():
res = True
elif self.update():
res = True
return res
return self.update()
return False

def update(self):
"""
Reload the keys if necessary.

This is a forced update, will happen even if cache time has not elapsed.
Replaced keys will be marked as inactive and not removed.

:return: True if update was ok or False if we encountered an error during update.
"""
res = True # An update was successful
if self.source:
_old_keys = self._keys # just in case

Expand All @@ -459,24 +483,27 @@ def update(self):
try:
if self.local:
if self.fileformat in ["jwks", "jwk"]:
self.do_local_jwk(self.source)
updated = self.do_local_jwk(self.source)
elif self.fileformat == "der":
self.do_local_der(self.source, self.keytype, self.keyusage)
updated = self.do_local_der(self.source, self.keytype, self.keyusage)
elif self.remote:
res = self.do_remote()
updated = self.do_remote()
except Exception as err:
LOGGER.error("Key bundle update failed: %s", err)
self._keys = _old_keys # restore
return False

now = time.time()
for _key in _old_keys:
if _key not in self._keys:
if not _key.inactive_since: # If already marked don't mess
_key.inactive_since = now
self._keys.append(_key)
if updated:
now = time.time()
for _key in _old_keys:
if _key not in self._keys:
if not _key.inactive_since: # If already marked don't mess
_key.inactive_since = now
self._keys.append(_key)
else:
self._keys = _old_keys

return res
return True

def get(self, typ="", only_active=True):
"""
Expand Down
45 changes: 44 additions & 1 deletion tests/test_03_key_bundle.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from cryptojwt.jwk.rsa import import_rsa_key_from_cert_file
from cryptojwt.jwk.rsa import new_rsa_key
from cryptojwt.key_bundle import KeyBundle
from cryptojwt.key_bundle import UpdateFailed
from cryptojwt.key_bundle import build_key_bundle
from cryptojwt.key_bundle import dump_jwks
from cryptojwt.key_bundle import init_key
Expand Down Expand Up @@ -566,6 +567,7 @@ def test_update_2():
ec_key = new_ec_key(crv="P-256", key_ops=["sign"])
_jwks = {"keys": [rsa_key.serialize(), ec_key.serialize()]}

time.sleep(0.5)
with open(fname, "w") as fp:
fp.write(json.dumps(_jwks))

Expand Down Expand Up @@ -1008,7 +1010,7 @@ def test_remote_not_modified():

with responses.RequestsMock() as rsps:
rsps.add(method="GET", url=source, status=304, headers=headers)
assert kb.do_remote()
assert not kb.do_remote()
assert kb.last_remote == headers.get("Last-Modified")
timeout2 = kb.time_out

Expand All @@ -1018,9 +1020,50 @@ def test_remote_not_modified():
kb2 = KeyBundle().load(exp)
assert kb2.source == source
assert len(kb2.keys()) == 3
assert len(kb2.active_keys()) == 3
assert len(kb2.get("rsa")) == 1
assert len(kb2.get("oct")) == 1
assert len(kb2.get("ec")) == 1
assert kb2.httpc_params == {"timeout": (2, 2)}
assert kb2.imp_jwks
assert kb2.last_updated


def test_ignore_errors_period():
source_good = "https://example.com/keys.json"
source_bad = "https://example.com/keys-bad.json"
ignore_errors_period = 1
# Mock response
with responses.RequestsMock() as rsps:
rsps.add(method="GET", url=source_good, json=JWKS_DICT, status=200)
rsps.add(method="GET", url=source_bad, json=JWKS_DICT, status=500)
httpc_params = {"timeout": (2, 2)} # connect, read timeouts in seconds
kb = KeyBundle(
source=source_good,
httpc=requests.request,
httpc_params=httpc_params,
ignore_errors_period=ignore_errors_period,
)
res = kb.do_remote()
assert res == True
assert kb.ignore_errors_until is None

# refetch, but fail by using a bad source
kb.source = source_bad
try:
res = kb.do_remote()
except UpdateFailed:
pass

# retry should fail silently as we're in holddown
res = kb.do_remote()
assert kb.ignore_errors_until is not None
assert res == False

# wait until holddown
time.sleep(ignore_errors_period + 1)

# try again
kb.source = source_good
res = kb.do_remote()
assert res == True
6 changes: 6 additions & 0 deletions tests/test_04_key_jar.py
Original file line number Diff line number Diff line change
Expand Up @@ -746,6 +746,12 @@ def test_aud(self):
keys = self.bob_keyjar.get_jwt_verify_keys(_jwt.jwt, no_kid_issuer=no_kid_issuer)
assert len(keys) == 1

def test_inactive_verify_key(self):
_jwt = factory(self.sjwt_b)
self.alice_keyjar.return_issuer("Bob")[0].mark_all_as_inactive()
keys = self.alice_keyjar.get_jwt_verify_keys(_jwt.jwt)
assert len(keys) == 0


def test_copy():
kj = KeyJar()
Expand Down
2 changes: 1 addition & 1 deletion tox.ini
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ envlist = py{36,37,38},quality
[testenv]
passenv = CI TRAVIS TRAVIS_*
commands =
py.test --cov=cryptojwt --isort --black {posargs}
pytest -vvv -ra --cov=cryptojwt --isort --black {posargs}
codecov
extras = testing
deps =
Expand Down