diff --git a/pyproject.toml b/pyproject.toml index e64f8e7..5c9de3f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -2,7 +2,7 @@ [tool.poetry] name = "cryptojwt" -version = "1.9.3" +version = "1.9.4" description = "Python implementation of JWT, JWE, JWS and JWK" authors = ["Roland Hedberg "] license = "Apache-2.0" diff --git a/src/cryptojwt/key_bundle.py b/src/cryptojwt/key_bundle.py index 0b286fa..172e564 100755 --- a/src/cryptojwt/key_bundle.py +++ b/src/cryptojwt/key_bundle.py @@ -40,6 +40,8 @@ LOGGER = logging.getLogger(__name__) +JWKS_CONTENT_TYPES = set(["application/json", "application/jwk-set+json"]) + # def raise_exception(excep, descr, error='service_error'): # _err = json.dumps({'error': error, 'error_description': descr}) # raise excep(_err, 'application/json') @@ -528,8 +530,8 @@ def _parse_remote_response(self, response): """ # Check if the content type is the right one. try: - if not check_content_type(response.headers["Content-Type"], "application/json"): - LOGGER.warning("Wrong Content_type (%s)", response.headers["Content-Type"]) + if not check_content_type(response.headers["Content-Type"], JWKS_CONTENT_TYPES): + LOGGER.warning("Wrong Content-Type (%s)", response.headers["Content-Type"]) except KeyError: pass diff --git a/src/cryptojwt/utils.py b/src/cryptojwt/utils.py index db41d64..4fa5250 100644 --- a/src/cryptojwt/utils.py +++ b/src/cryptojwt/utils.py @@ -8,7 +8,7 @@ import warnings from binascii import unhexlify from email.message import EmailMessage -from typing import List +from typing import List, Set, Union from cryptojwt.exception import BadSyntax @@ -262,12 +262,17 @@ def httpc_params_loader(httpc_params): return httpc_params -def check_content_type(content_type, mime_type): +def check_content_type(content_type: str, mime_type: Union[str, List[str], Set[str]]): """Return True if the content type contains the MIME type""" msg = EmailMessage() msg["content-type"] = content_type mt = msg.get_content_type() - return mime_type == mt + if isinstance(mime_type, str): + return mt == mime_type + elif isinstance(mime_type, (list, set)): + return mt in mime_type + else: + raise ValueError("Invalid MIME type argument") def is_compact_jws(token): diff --git a/tests/test_31_utils.py b/tests/test_31_utils.py index 86e98a5..d965a9c 100644 --- a/tests/test_31_utils.py +++ b/tests/test_31_utils.py @@ -1,3 +1,5 @@ +import pytest + from cryptojwt.utils import check_content_type @@ -15,3 +17,19 @@ def test_check_content_type(): ) is False ) + assert ( + check_content_type( + content_type="application/jwk-set+json;charset=UTF-8", + mime_type="application/application/jwk-set+json", + ) + is False + ) + assert ( + check_content_type( + content_type="application/jwk-set+json;charset=UTF-8", + mime_type=set(["application/application/jwk-set+json", "application/json"]), + ) + is False + ) + with pytest.raises(ValueError): + check_content_type(content_type="application/jwk-set+json;charset=UTF-8", mime_type=42)