Skip to content

Commit ed892c4

Browse files
feat: Check if token is a JWT (#529)
1 parent ff93607 commit ed892c4

File tree

2 files changed

+29
-1
lines changed

2 files changed

+29
-1
lines changed

postgrest/base_client.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55

66
from httpx import BasicAuth, Timeout
77

8-
from .utils import AsyncClient, SyncClient, is_http_url
8+
from .utils import AsyncClient, SyncClient, is_http_url, is_valid_jwt
99

1010

1111
class BasePostgrestClient(ABC):
@@ -58,6 +58,8 @@ def auth(
5858
Bearer token is preferred if both ones are provided.
5959
"""
6060
if token:
61+
if not is_valid_jwt(token):
62+
ValueError("token must be a valid JWT authorization token")
6163
self.session.headers["Authorization"] = f"Bearer {token}"
6264
elif username:
6365
self.session.auth = BasicAuth(username, password)

postgrest/utils.py

+26
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,14 @@
11
from __future__ import annotations
22

3+
import re
34
from typing import Any, Type, TypeVar, cast, get_origin
45
from urllib.parse import urlparse
56

67
from httpx import AsyncClient # noqa: F401
78
from httpx import Client as BaseClient # noqa: F401
89

10+
BASE64URL_REGEX = r"^([a-z0-9_-]{4})*($|[a-z0-9_-]{3}$|[a-z0-9_-]{2}$)$"
11+
912

1013
class SyncClient(BaseClient):
1114
def aclose(self) -> None:
@@ -40,3 +43,26 @@ def get_origin_and_cast(typ: type[type[_T]]) -> type[_T]:
4043

4144
def is_http_url(url: str) -> bool:
4245
return urlparse(url).scheme in {"https", "http"}
46+
47+
48+
def is_valid_jwt(value: str) -> bool:
49+
"""Checks if value looks like a JWT, does not do any extra parsing."""
50+
if not isinstance(value, str):
51+
return False
52+
53+
# Remove trailing whitespaces if any.
54+
value = value.strip()
55+
56+
# Remove "Bearer " prefix if any.
57+
if value.startswith("Bearer "):
58+
value = value[7:]
59+
60+
# Valid JWT must have 2 dots (Header.Paylod.Signature)
61+
if value.count(".") != 2:
62+
return False
63+
64+
for part in value.split("."):
65+
if not re.search(BASE64URL_REGEX, part, re.IGNORECASE):
66+
return False
67+
68+
return True

0 commit comments

Comments
 (0)