Skip to content

implement draft CachedContent handler #112

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 11 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
173 changes: 173 additions & 0 deletions src/cryptojwt/cached_content.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,173 @@
import json
import logging
import os
import threading
import time
from abc import ABC
from abc import abstractmethod
from datetime import datetime
from typing import Callable
from typing import List
from typing import Optional

import requests

from .exception import UpdateFailed
from .utils import httpc_params_loader

DEFAULT_CACHE_TIME = 300


class NotModified(Exception):
pass


class CachedContent(ABC):
def __init__(
self,
source: str,
cache_time: int = DEFAULT_CACHE_TIME,
ignore_errors_period: int = 0,
deserializer: Optional[Callable] = None,
**kwargs,
):
self.logger = logging.getLogger(__name__).getChild(self.__class__.__name__)
self.lock = threading.Lock()
self.source = source
self.cache_time = cache_time
self.ignore_errors_period = ignore_errors_period
self.ignore_errors_until = None
self.last_update = None
self.next_update = 0
self.deserializer = deserializer or (lambda x: x)
self.content = None

def update(self, force: bool = False, fatal: bool = False) -> bool:
"""Update last cached content, return True if updated"""
if not force and time.time() < self.next_update:
return False
with self.lock:
if self.ignore_errors_until and time.time() < self.ignore_errors_until:
self.logger.warning(
"Skip updating content from %s (in error holddown until %s)",
self.source,
datetime.fromtimestamp(self.ignore_errors_until),
)
else:
try:
content = self.read_content(force=force)
self.content = self.deserializer(content)
self.last_update = time.time()
self.next_update = self.last_update + self.cache_time
self.ignore_errors_until = None
except NotModified:
return False
except Exception as exc:
print(exc)
self.logger.error("Content update %s failed: %s", self.source, exc)
if self.ignore_errors_period:
self.ignore_errors_until = time.time() + self.ignore_errors_period
if fatal:
raise UpdateFailed(str(exc))
return False
return True

def get(self, update: bool = False, force: bool = False, **kwargs):
"""Get last cached content, update if requested to"""
if update or force or self.content is None:
self.update(force=force, **kwargs)
return self.content

@abstractmethod
def read_content(self, force: bool = False):
pass

@classmethod
def from_source(cls, source: str, **kwargs):
if source.startswith("http://") or source.startswith("https://"):
return CachedContentHTTP(url=source, **kwargs)
else:
return CachedContentFile(filename=source, **kwargs)


class CachedContentFile(CachedContent):
def __init__(self, filename: str, **kwargs):
super().__init__(source=filename, **kwargs)
self.filename = filename
self.last_modified = None

def read_content(self, force: bool = False):
last_modified = os.stat(self.filename).st_mtime
if not force:
if last_modified == self.last_modified:
self.logger.debug("%s not modified since last refresh", self.filename)
raise NotModified
else:
self.logger.debug("Refresh forced")
self.logger.debug("%s modified", self.filename)
self.last_modified = last_modified
with open(self.filename, "rb") as file:
t1 = time.perf_counter()
content = file.read()
t2 = time.perf_counter()
self.logger.info(
"Load for %s took %.3f seconds",
self.filename,
t2 - t1,
extra={
"filename": self.filename,
"content_length": len(content),
"last_modified": self.last_modified,
"duration": t2 - t1,
},
)
return content


class CachedContentHTTP(CachedContent):
def __init__(
self,
url: str,
httpc=None,
httpc_params=None,
**kwargs,
):
super().__init__(source=url, **kwargs)
self.url = url
self.http_etag = None
self.http_date = None
self.httpc = httpc if httpc else requests.request
self.httpc_params = httpc_params_loader(httpc_params)

def read_content(self, force: bool = False):
"""Refresh content fetched via HTTP"""
httpc_params = self.httpc_params.copy()
if "headers" not in httpc_params:
httpc_params["headers"] = {}
if not force:
if self.http_etag:
httpc_params["headers"]["If-None-Match"] = self.http_etag
elif self.http_date:
httpc_params["headers"]["If-Modified-Since"] = self.http_date
t1 = time.perf_counter()
response = self.httpc("GET", self.url, **httpc_params)
t2 = time.perf_counter()
self.logger.info(
"GET for %s took %.3f seconds",
self.url,
t2 - t1,
extra={
"url": self.url,
"content_length": len(response.content),
"last_modified": response.headers.get("date"),
"http_status": response.status_code,
"duration": t2 - t1,
},
)
response.raise_for_status()
if response.status_code == 304:
raise NotModified
self.http_etag = response.headers.get("etag")
self.http_date = response.headers.get("date")
self.logger.debug("%s updated", self.url)
return response.content
83 changes: 83 additions & 0 deletions src/cryptojwt/jwk/serialization.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
import json
from typing import List

from cryptojwt.jwk.x509 import import_private_key_from_pem_data
from cryptojwt.jwk.x509 import import_public_key_from_pem_data

from ..exception import UnknownKeyType
from . import JWK
from .ec import ECKey
from .hmac import SYMKey
from .jwk import key_from_jwk_dict
from .rsa import RSAKey
from .utils import harmonize_usage

K2C = {"RSA": RSAKey, "EC": ECKey, "oct": SYMKey}


def jwks_deserializer(data) -> List[JWK]:
"""Convert JWKS dictionary (as str or bytes) to JWK objects"""
keys = json.loads(data.decode() if isinstance(data, bytes) else data)
if isinstance(keys, dict) and "keys" in keys:
return [key_from_jwk_dict(k) for k in keys["keys"]]
elif isinstance(keys, list):
return [key_from_jwk_dict(k) for k in keys]
raise ValueError("Unknown JWKS format")


def der_private_deserializer(data, keytype, keyusage=None, kid=None) -> List[JWK]:
"""Convert PEM-encoded DER (as str or bytes) to JWK objects"""
key_dict = {}
_kty = keytype.lower()
if _kty in ["rsa", "ec"]:
key_dict["kty"] = _kty
_key = import_private_key_from_pem_data(data if isinstance(data, bytes) else data.encode())
key_dict["priv_key"] = _key
key_dict["pub_key"] = _key.public_key()
else:
raise NotImplementedError("No support for DER decoding of key type {}".format(_kty))
if not keyusage:
key_dict["use"] = ["enc", "sig"]
else:
key_dict["use"] = harmonize_usage(keyusage)
if kid:
key_dict["kid"] = kid
return jwk_dict_as_keys(key_dict)


def jwk_dict_as_keys(jwk_dict) -> List[JWK]:
"""
Return JWK dictionary as JWK objects

:param keys: JWK dictionary
:return: list of JWK objects
"""

res = []
kty = jwk_dict["kty"]

if kty.lower() in K2C:
jwk_dict["kty"] = kty.lower()
elif kty.upper() in K2C:
jwk_dict["kty"] = kty.upper()
else:
raise UnknownKeyType(jwk_dict)

try:
usage = harmonize_usage(jwk_dict["use"])
except KeyError:
usage = [""]
else:
del jwk_dict["use"]

kty = jwk_dict["kty"]
for use in usage:
try:
key = K2C[kty](use=use, **jwk_dict)
except KeyError:
raise UnknownKeyType(jwk_dict)
if not key.kid:
key.add_kid()
res.append(key)

return res
19 changes: 19 additions & 0 deletions src/cryptojwt/jwk/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,3 +38,22 @@ def sha512_digest(msg):
"SHA-384": sha384_digest,
"SHA-512": sha512_digest,
}

MAP = {"dec": "enc", "enc": "enc", "ver": "sig", "sig": "sig"}


def harmonize_usage(use):
"""

:param use:
:return: list of usage
"""
if isinstance(use, str):
return [MAP[use]]

if isinstance(use, list):
_ul = list(MAP.keys())
_us = {MAP[u] for u in use if u in _ul}
return list(_us)

return None
17 changes: 14 additions & 3 deletions src/cryptojwt/jwk/x509.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,12 +35,23 @@ def import_private_key_from_pem_file(filename, passphrase=None):
:return: A private key instance
"""
with open(filename, "rb") as key_file:
private_key = serialization.load_pem_private_key(
key_file.read(), password=passphrase, backend=default_backend()
)
private_key = import_private_key_from_pem_data(key_file.read(), passphrase=passphrase)
return private_key


def import_private_key_from_pem_data(pem_data, passphrase=None):
"""
Read a private key from a PEM data.

:param pem_data: Bytes of pem data
:param passphrase: A pass phrase to use to unpack the PEM file.
:return: A private key instance
"""
return serialization.load_pem_private_key(
pem_data, password=passphrase, backend=default_backend()
)


PREFIX = "-----BEGIN CERTIFICATE-----"
POSTFIX = "-----END CERTIFICATE-----"

Expand Down
20 changes: 1 addition & 19 deletions src/cryptojwt/key_bundle.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
from .jwk.jwk import import_jwk
from .jwk.rsa import RSAKey
from .jwk.rsa import new_rsa_key
from .jwk.utils import harmonize_usage
from .utils import as_unicode
from .utils import httpc_params_loader

Expand All @@ -47,25 +48,6 @@
# Make sure the keys are all uppercase
K2C = {"RSA": RSAKey, "EC": ECKey, "oct": SYMKey}

MAP = {"dec": "enc", "enc": "enc", "ver": "sig", "sig": "sig"}


def harmonize_usage(use):
"""

:param use:
:return: list of usage
"""
if isinstance(use, str):
return [MAP[use]]

if isinstance(use, list):
_ul = list(MAP.keys())
_us = {MAP[u] for u in use if u in _ul}
return list(_us)

return None


def rsa_init(spec):
"""
Expand Down
Loading