Skip to content

Commit e1eabfd

Browse files
authored
Merge pull request #80 from IdentityPython/develop
Changes as an effect of changing persistent storage model.
2 parents a54f653 + 99f3780 commit e1eabfd

File tree

8 files changed

+286
-107
lines changed

8 files changed

+286
-107
lines changed

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ exclude_lines = [
2222

2323
[tool.poetry]
2424
name = "cryptojwt"
25-
version = "1.4.1"
25+
version = "1.5.0"
2626
description = "Python implementation of JWT, JWE, JWS and JWK"
2727
authors = ["Roland Hedberg <[email protected]>"]
2828
license = "Apache-2.0"

src/cryptojwt/key_bundle.py

Lines changed: 103 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@
66
import time
77
from datetime import datetime
88
from functools import cmp_to_key
9+
from typing import List
10+
from typing import Optional
911

1012
import requests
1113

@@ -24,7 +26,6 @@
2426
from .jwk.jwk import dump_jwk
2527
from .jwk.jwk import import_jwk
2628
from .jwk.rsa import RSAKey
27-
from .jwk.rsa import import_private_rsa_key_from_file
2829
from .jwk.rsa import new_rsa_key
2930
from .utils import as_unicode
3031

@@ -152,6 +153,26 @@ def ec_init(spec):
152153
class KeyBundle:
153154
"""The Key Bundle"""
154155

156+
params = {
157+
"cache_time": 0,
158+
"etag": "",
159+
"fileformat": "jwks",
160+
"httpc_params": {},
161+
"ignore_errors_period": 0,
162+
"ignore_errors_until": None,
163+
"ignore_invalid_keys": True,
164+
"imp_jwks": None,
165+
"keytype": "RSA",
166+
"keyusage": None,
167+
"last_local": None,
168+
"last_remote": None,
169+
"last_updated": 0,
170+
"local": False,
171+
"remote": False,
172+
"source": None,
173+
"time_out": 0,
174+
}
175+
155176
def __init__(
156177
self,
157178
keys=None,
@@ -189,22 +210,22 @@ def __init__(
189210
"""
190211

191212
self._keys = []
192-
self.remote = False
193-
self.local = False
194213
self.cache_time = cache_time
195-
self.ignore_errors_period = ignore_errors_period
196-
self.ignore_errors_until = None # UNIX timestamp of last error
197-
self.time_out = 0
198214
self.etag = ""
199-
self.source = None
200215
self.fileformat = fileformat.lower()
216+
self.ignore_errors_period = ignore_errors_period
217+
self.ignore_errors_until = None # UNIX timestamp of last error
218+
self.ignore_invalid_keys = ignore_invalid_keys
219+
self.imp_jwks = None
201220
self.keytype = keytype
202221
self.keyusage = keyusage
203-
self.imp_jwks = None
204-
self.last_updated = 0
205-
self.last_remote = None # HTTP Date of last remote update
206222
self.last_local = None # UNIX timestamp of last local update
207-
self.ignore_invalid_keys = ignore_invalid_keys
223+
self.last_remote = None # HTTP Date of last remote update
224+
self.last_updated = 0
225+
self.local = False
226+
self.remote = False
227+
self.source = None
228+
self.time_out = 0
208229

209230
if httpc:
210231
self.httpc = httpc
@@ -490,6 +511,7 @@ def update(self):
490511

491512
# reread everything
492513
self._keys = []
514+
updated = None
493515

494516
try:
495517
if self.local:
@@ -751,48 +773,68 @@ def difference(self, bundle):
751773

752774
return [k for k in self._keys if k not in bundle]
753775

754-
def dump(self):
755-
_keys = []
756-
for _k in self._keys:
757-
_ser = _k.to_dict()
758-
if _k.inactive_since:
759-
_ser["inactive_since"] = _k.inactive_since
760-
_keys.append(_ser)
761-
762-
res = {
763-
"keys": _keys,
764-
"fileformat": self.fileformat,
765-
"last_updated": self.last_updated,
766-
"last_remote": self.last_remote,
767-
"last_local": self.last_local,
768-
"httpc_params": self.httpc_params,
769-
"remote": self.remote,
770-
"local": self.local,
771-
"imp_jwks": self.imp_jwks,
772-
"time_out": self.time_out,
773-
"cache_time": self.cache_time,
774-
}
776+
def dump(self, exclude_attributes: Optional[List[str]] = None):
777+
if exclude_attributes is None:
778+
exclude_attributes = []
775779

776-
if self.source:
777-
res["source"] = self.source
780+
res = {}
781+
782+
if "keys" not in exclude_attributes:
783+
_keys = []
784+
for _k in self._keys:
785+
_ser = _k.to_dict()
786+
if _k.inactive_since:
787+
_ser["inactive_since"] = _k.inactive_since
788+
_keys.append(_ser)
789+
res["keys"] = _keys
790+
791+
for attr, default in self.params.items():
792+
if attr in exclude_attributes:
793+
continue
794+
val = getattr(self, attr)
795+
res[attr] = val
778796

779797
return res
780798

781799
def load(self, spec):
800+
"""
801+
Sets attributes according to a specification.
802+
Does not overwrite an existing attributes value with a default value.
803+
804+
:param spec: Dictionary with attributes and value to populate the instance with
805+
:return: The instance itself
806+
"""
782807
_keys = spec.get("keys", [])
783808
if _keys:
784809
self.do_keys(_keys)
785-
self.source = spec.get("source", None)
786-
self.fileformat = spec.get("fileformat", "jwks")
787-
self.last_updated = spec.get("last_updated", 0)
788-
self.last_remote = spec.get("last_remote", None)
789-
self.last_local = spec.get("last_local", None)
790-
self.remote = spec.get("remote", False)
791-
self.local = spec.get("local", False)
792-
self.imp_jwks = spec.get("imp_jwks", None)
793-
self.time_out = spec.get("time_out", 0)
794-
self.cache_time = spec.get("cache_time", 0)
795-
self.httpc_params = spec.get("httpc_params", {})
810+
811+
for attr, default in self.params.items():
812+
val = spec.get(attr)
813+
if val:
814+
setattr(self, attr, val)
815+
816+
return self
817+
818+
def flush(self):
819+
self._keys = []
820+
self.cache_time = (300,)
821+
self.etag = ""
822+
self.fileformat = "jwks"
823+
# self.httpc=None,
824+
self.httpc_params = (None,)
825+
self.ignore_errors_period = 0
826+
self.ignore_errors_until = None
827+
self.ignore_invalid_keys = True
828+
self.imp_jwks = None
829+
self.keytype = ("RSA",)
830+
self.keyusage = (None,)
831+
self.last_local = None # UNIX timestamp of last local update
832+
self.last_remote = None # HTTP Date of last remote update
833+
self.last_updated = 0
834+
self.local = False
835+
self.remote = False
836+
self.source = None
837+
self.time_out = 0
796838
return self
797839

798840

@@ -1246,3 +1288,19 @@ def init_key(filename, type, kid="", **kwargs):
12461288
_new_key = key_gen(type, kid=kid, **kwargs)
12471289
dump_jwk(filename, _new_key)
12481290
return _new_key
1291+
1292+
1293+
def key_by_alg(alg: str):
1294+
if alg.startswith("RS"):
1295+
return key_gen("RSA", alg="RS256")
1296+
elif alg.startswith("ES"):
1297+
if alg == "ES256":
1298+
return key_gen("EC", crv="P-256")
1299+
elif alg == "ES384":
1300+
return key_gen("EC", crv="P-384")
1301+
elif alg == "ES512":
1302+
return key_gen("EC", crv="P-521")
1303+
elif alg.startswith("HS"):
1304+
return key_gen("sym")
1305+
1306+
raise ValueError("Don't know who to create a key to use with '{}'".format(alg))

src/cryptojwt/key_issuer.py

Lines changed: 50 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
import json
22
import logging
33
import os
4+
from typing import List
5+
from typing import Optional
46

57
from requests import request
68

@@ -15,13 +17,21 @@
1517

1618
__author__ = "Roland Hedberg"
1719

18-
1920
logger = logging.getLogger(__name__)
2021

2122

2223
class KeyIssuer(object):
2324
""" A key issuer instance contains a number of KeyBundles. """
2425

26+
params = {
27+
"ca_certs": None,
28+
"httpc_params": None,
29+
"keybundle_cls": KeyBundle,
30+
"name": "",
31+
"remove_after": 3600,
32+
"spec2key": None,
33+
}
34+
2535
def __init__(
2636
self,
2737
ca_certs=None,
@@ -45,14 +55,13 @@ def __init__(
4555

4656
self._bundles = []
4757

48-
self.keybundle_cls = keybundle_cls
49-
self.name = name
50-
51-
self.spec2key = {}
5258
self.ca_certs = ca_certs
53-
self.remove_after = remove_after
5459
self.httpc = httpc or request
5560
self.httpc_params = httpc_params or {}
61+
self.keybundle_cls = keybundle_cls
62+
self.name = name
63+
self.remove_after = remove_after
64+
self.spec2key = {}
5665

5766
def __repr__(self) -> str:
5867
return '<KeyIssuer "{}" {}>'.format(self.name, self.key_summary())
@@ -350,43 +359,57 @@ def __len__(self):
350359
nr += len(kb)
351360
return nr
352361

353-
def dump(self, exclude=None):
362+
def dump(self, exclude_attributes: Optional[List[str]] = None) -> dict:
354363
"""
355364
Returns the content as a dictionary.
356365
366+
:param exclude_attributes: List of attribute names for objects that should be ignored.
357367
:return: A dictionary
358368
"""
359369

360-
_bundles = []
361-
for kb in self._bundles:
362-
_bundles.append(kb.dump())
363-
364-
info = {
365-
"name": self.name,
366-
"bundles": _bundles,
367-
"keybundle_cls": qualified_name(self.keybundle_cls),
368-
"spec2key": self.spec2key,
369-
"ca_certs": self.ca_certs,
370-
"remove_after": self.remove_after,
371-
"httpc_params": self.httpc_params,
372-
}
370+
if exclude_attributes is None:
371+
exclude_attributes = []
372+
373+
info = {}
374+
for attr, default in self.params.items():
375+
if attr in exclude_attributes:
376+
continue
377+
val = getattr(self, attr)
378+
if attr == "keybundle_cls":
379+
val = qualified_name(val)
380+
info[attr] = val
381+
382+
if "bundles" not in exclude_attributes:
383+
_bundles = []
384+
for kb in self._bundles:
385+
_bundles.append(kb.dump(exclude_attributes=exclude_attributes))
386+
info["bundles"] = _bundles
387+
373388
return info
374389

375390
def load(self, info):
376391
"""
377392
378-
:param items: A list with the information
393+
:param items: A dictionary with the information to load
379394
:return:
380395
"""
381-
self.name = info["name"]
382-
self.keybundle_cls = importer(info["keybundle_cls"])
383-
self.spec2key = info["spec2key"]
384-
self.ca_certs = info["ca_certs"]
385-
self.remove_after = info["remove_after"]
386-
self.httpc_params = info["httpc_params"]
396+
for attr, default in self.params.items():
397+
val = info.get(attr)
398+
if val:
399+
if attr == "keybundle_cls":
400+
val = importer(val)
401+
setattr(self, attr, val)
402+
387403
self._bundles = [KeyBundle().load(val) for val in info["bundles"]]
388404
return self
389405

406+
def flush(self):
407+
for attr, default in self.params.items():
408+
setattr(self, attr, default)
409+
410+
self._bundles = []
411+
return self
412+
390413
def update(self):
391414
for kb in self._bundles:
392415
kb.update()

0 commit comments

Comments
 (0)