DEFAULT_CACHE_TIME = 300


class NotModified(Exception):
    """Raised by ``read_content`` when the source is unchanged since the last read."""


class CachedContent(ABC):
    """
    Base class for content that is read from a source and cached in memory.

    Subclasses implement :meth:`read_content`. The raw result is passed through
    ``deserializer`` and stored in ``self.content`` until the next successful
    update. Updates are serialized with ``self.lock``.
    """

    def __init__(
        self,
        source: str,
        cache_time: int = DEFAULT_CACHE_TIME,
        ignore_errors_period: int = 0,
        deserializer: Optional[Callable] = None,
        **kwargs,
    ):
        """
        :param source: Identifier of the content source (used in log messages)
        :param cache_time: Seconds a successful update is considered fresh
        :param ignore_errors_period: Seconds during which update attempts are
            skipped after a failure (0 disables the holddown)
        :param deserializer: Callable applied to the raw content; identity if omitted
        :param kwargs: Accepted for signature compatibility and ignored
        """
        self.logger = logging.getLogger(__name__).getChild(self.__class__.__name__)
        self.lock = threading.Lock()
        self.source = source
        self.cache_time = cache_time
        self.ignore_errors_period = ignore_errors_period
        self.ignore_errors_until = None
        self.last_update = None
        self.next_update = 0
        self.deserializer = deserializer or (lambda x: x)
        self.content = None

    def update(self, force: bool = False, fatal: bool = False) -> bool:
        """
        Update the cached content.

        :param force: Ignore the cache time and ask ``read_content`` for an
            unconditional read. An active error holddown is still honoured.
        :param fatal: Raise instead of returning False when the update fails
        :return: True if new content was stored, False otherwise
        :raises UpdateFailed: if the update fails and ``fatal`` is set
        """
        if not force and time.time() < self.next_update:
            return False
        with self.lock:
            if self.ignore_errors_until and time.time() < self.ignore_errors_until:
                self.logger.warning(
                    "Skip updating content from %s (in error holddown until %s)",
                    self.source,
                    datetime.fromtimestamp(self.ignore_errors_until),
                )
                # Nothing was updated, so do not report success.
                return False
            try:
                content = self.read_content(force=force)
                self.content = self.deserializer(content)
            except NotModified:
                # The source confirmed the cached content is current: re-arm
                # the cache timer so the source is not re-checked on every
                # call, but leave last_update untouched.
                self.next_update = time.time() + self.cache_time
                return False
            except Exception as exc:
                self.logger.error("Content update %s failed: %s", self.source, exc)
                if self.ignore_errors_period:
                    self.ignore_errors_until = time.time() + self.ignore_errors_period
                if fatal:
                    raise UpdateFailed(str(exc)) from exc
                return False
            self.last_update = time.time()
            self.next_update = self.last_update + self.cache_time
            self.ignore_errors_until = None
            return True

    def get(self, update: bool = False, force: bool = False, **kwargs):
        """
        Return the last cached content, updating it first if requested.

        An update is also attempted when nothing has been cached yet.

        :param update: Attempt an update (subject to the cache time)
        :param force: Attempt a forced update
        :param kwargs: Extra keyword arguments for :meth:`update`, e.g. ``fatal``
        :return: The deserialized content, or None if none has been loaded
        """
        if update or force or self.content is None:
            self.update(force=force, **kwargs)
        return self.content

    @abstractmethod
    def read_content(self, force: bool = False):
        """
        Read and return the raw content from the source.

        :param force: Read unconditionally instead of checking for modification
        :raises NotModified: if the source has not changed since the last read
        """

    @classmethod
    def from_source(cls, source: str, **kwargs):
        """
        Create the cache implementation matching ``source``.

        :param source: A ``http://``/``https://`` URL or a local file name
        :param kwargs: Passed on to the selected class
        :return: A CachedContentHTTP for URLs, otherwise a CachedContentFile
        """
        if source.startswith(("http://", "https://")):
            return CachedContentHTTP(url=source, **kwargs)
        return CachedContentFile(filename=source, **kwargs)


class CachedContentFile(CachedContent):
    """Cached content read from a local file, revalidated by modification time."""

    def __init__(self, filename: str, **kwargs):
        """
        :param filename: Path of the file to read
        :param kwargs: Passed on to CachedContent
        """
        super().__init__(source=filename, **kwargs)
        self.filename = filename
        self.last_modified = None

    def read_content(self, force: bool = False):
        """
        Return the file content as bytes.

        :param force: Read even if the modification time is unchanged
        :raises NotModified: if not forced and the modification time equals the
            one recorded at the last successful read
        """
        last_modified = os.stat(self.filename).st_mtime
        if force:
            self.logger.debug("Refresh forced")
        elif last_modified == self.last_modified:
            self.logger.debug("%s not modified since last refresh", self.filename)
            raise NotModified
        self.logger.debug("%s modified", self.filename)
        with open(self.filename, "rb") as file:
            t1 = time.perf_counter()
            content = file.read()
            t2 = time.perf_counter()
        # Record the timestamp only after a successful read; otherwise a failed
        # read would make the next attempt report NotModified.
        self.last_modified = last_modified
        self.logger.info(
            "Load for %s took %.3f seconds",
            self.filename,
            t2 - t1,
            extra={
                # "filename" is a reserved LogRecord attribute; using it as an
                # extra key raises KeyError as soon as INFO logging is enabled.
                "source_filename": self.filename,
                "content_length": len(content),
                "last_modified": self.last_modified,
                "duration": t2 - t1,
            },
        )
        return content


class CachedContentHTTP(CachedContent):
    """Cached content fetched with HTTP GET, revalidated with conditional requests."""

    def __init__(
        self,
        url: str,
        httpc=None,
        httpc_params=None,
        **kwargs,
    ):
        """
        :param url: URL to fetch
        :param httpc: Callable invoked as ``httpc("GET", url, **httpc_params)``;
            defaults to ``requests.request``
        :param httpc_params: Parameters for ``httpc``, loaded via ``httpc_params_loader``
        :param kwargs: Passed on to CachedContent
        """
        super().__init__(source=url, **kwargs)
        self.url = url
        self.http_etag = None
        self.http_date = None
        self.httpc = httpc if httpc else requests.request
        self.httpc_params = httpc_params_loader(httpc_params)

    def read_content(self, force: bool = False):
        """
        Refresh content fetched via HTTP.

        :param force: Send an unconditional request (no validator headers)
        :return: The response body as bytes
        :raises NotModified: if the server answers 304
        """
        httpc_params = dict(self.httpc_params or {})
        # Copy the headers as well: the copy above is shallow, so adding the
        # conditional headers in place would leak them into self.httpc_params
        # and turn later forced requests into conditional ones.
        headers = dict(httpc_params.get("headers") or {})
        httpc_params["headers"] = headers
        if not force:
            if self.http_etag:
                headers["If-None-Match"] = self.http_etag
            elif self.http_date:
                headers["If-Modified-Since"] = self.http_date
        t1 = time.perf_counter()
        response = self.httpc("GET", self.url, **httpc_params)
        t2 = time.perf_counter()
        self.logger.info(
            "GET for %s took %.3f seconds",
            self.url,
            t2 - t1,
            extra={
                "url": self.url,
                "content_length": len(response.content),
                "last_modified": response.headers.get("date"),
                "http_status": response.status_code,
                "duration": t2 - t1,
            },
        )
        response.raise_for_status()
        if response.status_code == 304:
            raise NotModified
        self.http_etag = response.headers.get("etag")
        self.http_date = response.headers.get("date")
        self.logger.debug("%s updated", self.url)
        return response.content
def jwks_deserializer(data) -> List[JWK]:
    """
    Convert a JWKS document (as str or bytes) to JWK objects.

    :param data: JSON text holding either an object with a "keys" member or a
        bare list of JWK dictionaries
    :return: list of JWK objects
    :raises ValueError: if the parsed JSON has neither of those shapes
    """
    keys = json.loads(data.decode() if isinstance(data, bytes) else data)
    if isinstance(keys, dict) and "keys" in keys:
        keys = keys["keys"]
    elif not isinstance(keys, list):
        raise ValueError("Unknown JWKS format")
    return [key_from_jwk_dict(k) for k in keys]


def der_private_deserializer(data, keytype, keyusage=None, kid=None) -> List[JWK]:
    """
    Convert a PEM-encoded private key (as str or bytes) to JWK objects.

    :param data: PEM data; a str is encoded to bytes before loading
    :param keytype: Key type, "rsa" or "ec" (case-insensitive)
    :param keyusage: Key usage passed through harmonize_usage; both "enc" and
        "sig" if omitted
    :param kid: Optional key ID
    :return: list of JWK objects, one per usage
    :raises NotImplementedError: for any other key type
    """
    _kty = keytype.lower()
    if _kty not in ("rsa", "ec"):
        raise NotImplementedError("No support for DER decoding of key type {}".format(_kty))

    _key = import_private_key_from_pem_data(data if isinstance(data, bytes) else data.encode())
    key_dict = {
        "kty": _kty,
        "priv_key": _key,
        "pub_key": _key.public_key(),
        "use": harmonize_usage(keyusage) if keyusage else ["enc", "sig"],
    }
    if kid:
        key_dict["kid"] = kid
    return jwk_dict_as_keys(key_dict)


def jwk_dict_as_keys(jwk_dict) -> List[JWK]:
    """
    Return JWK dictionary as JWK objects

    One key object is created per harmonized usage. The caller's dictionary is
    left unmodified.

    :param jwk_dict: JWK dictionary
    :return: list of JWK objects
    :raises UnknownKeyType: if "kty" does not match a supported key class
    """
    # Work on a shallow copy so normalising "kty" and removing "use" does not
    # leak into the caller's dictionary.
    spec = dict(jwk_dict)

    kty = spec["kty"]
    if kty.lower() in K2C:
        kty = kty.lower()
    elif kty.upper() in K2C:
        kty = kty.upper()
    else:
        raise UnknownKeyType(jwk_dict)
    spec["kty"] = kty

    usage = [""]
    if "use" in spec:
        # Always drop "use" from the constructor arguments; it is passed
        # explicitly below and a leftover entry would be a duplicate keyword.
        declared = spec.pop("use")
        try:
            usage = harmonize_usage(declared)
        except KeyError:
            # Unrecognised usage string: treat it like an unspecified usage.
            usage = [""]

    res = []
    for use in usage:
        try:
            key = K2C[kty](use=use, **spec)
        except KeyError as exc:
            raise UnknownKeyType(jwk_dict) from exc
        if not key.kid:
            key.add_kid()
        res.append(key)

    return res
MAP = {"dec": "enc", "enc": "enc", "ver": "sig", "sig": "sig"}


def harmonize_usage(use):
    """
    Map key usage names onto the canonical values "enc" and "sig".

    :param use: A usage name (str) or a list of usage names
    :return: list of usage. For a list, unknown entries are dropped and
        duplicates removed, keeping first-seen order. None if ``use`` is
        neither a str nor a list.
    :raises KeyError: if ``use`` is a str that is not a known usage name
    """
    if isinstance(use, str):
        return [MAP[use]]

    if isinstance(use, list):
        # dict.fromkeys de-duplicates while keeping insertion order; going
        # through a set made the result order depend on string hashing.
        return list(dict.fromkeys(MAP[u] for u in use if isinstance(u, str) and u in MAP))

    return None


def import_private_key_from_pem_data(pem_data, passphrase=None):
    """
    Read a private key from PEM data.

    :param pem_data: Bytes of PEM data
    :param passphrase: A pass phrase to use to unpack the PEM data.
    :return: A private key instance
    """
    return serialization.load_pem_private_key(
        pem_data, password=passphrase, backend=default_backend()
    )
BASEDIR = os.path.abspath(os.path.dirname(__file__))

JWKS_FILE = "test_keys/jwk.json"

RSA_PEM_FILE = "test_keys/rsa-2048-private.pem"
EC_PEM_FILE = "test_keys/ec-p256-private.pem"

JWKS_URL = "https://raw.githubusercontent.com/IdentityPython/JWTConnect-Python-CryptoJWT/main/tests/test_keys/jwk.json"
BAD_URL = "https://httpstat.us/404"


def full_path(local_file):
    """Return the path of a fixture file located relative to this test module."""
    return os.path.join(BASEDIR, local_file)


def _decode(raw):
    """Deserializer turning the raw bytes into text."""
    return raw.decode()


def _check_text_caching(source, expected_class):
    """Load text from ``source`` and check when ``last_update`` changes."""
    cc = CachedContent.from_source(source=source, deserializer=_decode)
    assert isinstance(cc, expected_class)
    assert cc.last_update is None
    content = cc.get()
    assert isinstance(content, str)
    assert cc.last_update is not None
    first_update = cc.last_update
    # Plain update requests must be answered from the cache.
    for _ in range(10):
        content = cc.get(update=True)
        assert isinstance(content, str)
        assert cc.last_update == first_update
    # A forced refresh must reload the content.
    cc.get(force=True)
    assert cc.last_update != first_update


def _check_json(source, expected_class):
    """Load ``source`` through ``json.loads`` and expect a dictionary."""
    cc = CachedContent.from_source(source=source, deserializer=json.loads)
    assert isinstance(cc, expected_class)
    content = cc.get()
    assert isinstance(content, dict)


def _check_keys(source, deserializer, expected_class, key_class):
    """Load ``source`` with a key deserializer and check the key types."""
    cc = CachedContent.from_source(source=source, deserializer=deserializer)
    assert isinstance(cc, expected_class)
    keys = cc.get()
    assert isinstance(keys, list)
    for key in keys:
        assert isinstance(key, key_class)


def test_local_text():
    _check_text_caching(full_path(JWKS_FILE), CachedContentFile)


def test_local_json():
    _check_json(full_path(JWKS_FILE), CachedContentFile)


def test_remote_text():
    _check_text_caching(JWKS_URL, CachedContentHTTP)


def test_remote_json():
    _check_json(JWKS_URL, CachedContentHTTP)


def test_local_jwks():
    _check_keys(full_path(JWKS_FILE), jwks_deserializer, CachedContentFile, JWK)


def test_local_pem_rsa_private():
    deserializer = functools.partial(der_private_deserializer, keytype="rsa")
    _check_keys(full_path(RSA_PEM_FILE), deserializer, CachedContentFile, RSAKey)


def test_local_pem_ec():
    deserializer = functools.partial(der_private_deserializer, keytype="ec")
    _check_keys(full_path(EC_PEM_FILE), deserializer, CachedContentFile, ECKey)


def test_remote_jwks():
    _check_keys(JWKS_URL, jwks_deserializer, CachedContentHTTP, JWK)


def test_remote_bad():
    cc = CachedContent.from_source(source=BAD_URL, ignore_errors_period=10)
    assert isinstance(cc, CachedContentHTTP)
    assert cc.last_update is None
    # The first failure is fatal on request ...
    with pytest.raises(UpdateFailed):
        cc.get(fatal=True)
    # ... after which the error holddown skips further attempts.
    content = cc.get()
    assert content is None
    content = cc.get(fatal=True)
    assert content is None