diff --git a/signedjson/key.py b/signedjson/key.py index 4d9b91d..e64e06a 100644 --- a/signedjson/key.py +++ b/signedjson/key.py @@ -100,6 +100,21 @@ def is_signing_algorithm_supported(key_id): return False +def decode_verify_key_base64(algorithm, version, key_base64): + # type: (str, str, str) -> VerifyKey + """Decode a base64 encoded verify key + Args: + algorithm (str): The algorithm the key is for (currently "ed25519"). + version (str): Identifies this key out of the keys for this entity. + key_base64 (str): Base64 encoded bytes of the key. + Returns: + A VerifyKey object. + """ + key_id = "%s:%s" % (algorithm, version) + key_bytes = decode_base64(key_base64) + return decode_verify_key_bytes(key_id, key_bytes) + + def decode_verify_key_bytes(key_id, key_bytes): # type: (str, bytes) -> VerifyKey """Decode a raw verify key @@ -146,8 +161,7 @@ def read_old_signing_keys(stream): keys = [] for line in stream: algorithm, version, expired, key_base64 = line.split() - key_name = "%s:%s" % (algorithm, version,) - key = decode_verify_key_bytes(key_name, decode_base64(key_base64)) + key = decode_verify_key_base64(algorithm, version, key_base64) key.expired = int(expired) keys.append(key) return keys diff --git a/tests/test_key.py b/tests/test_key.py index e163167..25ee4ff 100644 --- a/tests/test_key.py +++ b/tests/test_key.py @@ -1,17 +1,17 @@ - import unittest from signedjson.key import ( - generate_signing_key, - get_verify_key, decode_signing_key_base64, + decode_verify_key_base64, decode_verify_key_bytes, encode_signing_key_base64, - is_signing_algorithm_supported, encode_verify_key_base64, - read_signing_keys, + generate_signing_key, + get_verify_key, + is_signing_algorithm_supported, read_old_signing_keys, - write_signing_keys + read_signing_keys, + write_signing_keys, ) @@ -50,6 +50,25 @@ def test_decode_invalid_key(self): with self.assertRaises(Exception): decode_signing_key_base64("ed25519", self.version, "") + def test_decode_verify_key(self): + decoded_key = decode_verify_key_base64( + "ed25519", self.version, self.verify_key_base64 + ) + self.assertEquals(decoded_key.alg, "ed25519") + self.assertEquals(decoded_key.version, self.version) + + def test_decode_verify_key_invalid_base64(self): + with self.assertRaises(Exception): + decode_verify_key_base64("ed25519", self.version, "not base 64") + + def test_decode_verify_key_invalid_algorithm(self): + with self.assertRaises(Exception): + decode_verify_key_base64("not a valid alg", self.version, "") + + def test_decode_verify_key_invalid_key(self): + with self.assertRaises(Exception): + decode_verify_key_base64("ed25519", self.version, "") + def test_read_keys(self): stream = ["ed25519 %s %s" % (self.version, self.key_base64)] keys = read_signing_keys(stream) @@ -68,6 +87,7 @@ def test_write_signing_keys(self): class MockStream(object): def write(self, data): pass + write_signing_keys(MockStream(), [self.key])