Skip to content

Commit 6321ffe

Browse files
authored
Improve PostgrestClient.auth() (#14)
1 parent 71456e5 commit 6321ffe

File tree

3 files changed

+59
-21
lines changed

3 files changed

+59
-21
lines changed

CHANGELOG.md

+10-2
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,18 @@
1-
## CHANGELOG
1+
# CHANGELOG
22

33
### _Unreleased_
44

5-
#### Added
5+
#### Features
66

77
- Allow setting headers in `PostgrestClient`'s constructor
8+
- Improve `PostgrestClient.auth()` behavior
9+
10+
#### Internal
11+
12+
- Require Poetry >= 1.0.0
13+
- Update CI workflow
14+
- Use Dependabot
15+
- Update httpx to v0.19.0
816

917
### v0.4.0
1018

postgrest_py/client.py

+14-5
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import Dict, Union
1+
from typing import Dict, Optional, Union
22

33
from deprecation import deprecated
44
from httpx import AsyncClient, BasicAuth, Response
@@ -40,16 +40,25 @@ async def aclose(self) -> None:
4040

4141
def auth(
4242
self,
43-
token: str,
43+
token: Optional[str],
4444
*,
4545
username: Union[str, bytes, None] = None,
4646
password: Union[str, bytes] = "",
4747
):
48-
"""Authenticate the client with either bearer token or basic authentication."""
49-
if username:
48+
"""
49+
Authenticate the client with either bearer token or basic authentication.
50+
51+
Raise `ValueError` if neither authentication scheme is provided.
52+
Bearer token is preferred if both ones are provided.
53+
"""
54+
if token:
55+
self.session.headers["Authorization"] = f"Bearer {token}"
56+
elif username:
5057
self.session.auth = BasicAuth(username, password)
5158
else:
52-
self.session.headers["Authorization"] = f"Bearer {token}"
59+
raise ValueError(
60+
"Neither bearer token or basic authentication scheme is provided"
61+
)
5362
return self
5463

5564
def schema(self, schema: str):

tests/test_client.py

+35-14
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import pytest
2-
from httpx import BasicAuth
2+
from httpx import BasicAuth, Headers
33
from postgrest_py import PostgrestClient
44

55

@@ -9,33 +9,54 @@ async def postgrest_client():
99
yield client
1010

1111

12-
@pytest.mark.asyncio
13-
def test_constructor(postgrest_client):
14-
session = postgrest_client.session
12+
class TestConstructor:
13+
@pytest.mark.asyncio
14+
def test_simple(self, postgrest_client: PostgrestClient):
15+
session = postgrest_client.session
1516

16-
assert session.base_url == "https://example.com"
17-
default_headers = {
18-
"accept": "application/json",
19-
"content-type": "application/json",
20-
"accept-profile": "public",
21-
"content-profile": "public",
22-
}
23-
assert default_headers.items() <= session.headers.items()
17+
assert session.base_url == "https://example.com"
18+
headers = Headers(
19+
{
20+
"Accept": "application/json",
21+
"Content-Type": "application/json",
22+
"Accept-Profile": "public",
23+
"Content-Profile": "public",
24+
}
25+
)
26+
assert session.headers.items() >= headers.items()
27+
28+
@pytest.mark.asyncio
29+
async def test_custom_headers(self):
30+
async with PostgrestClient(
31+
"https://example.com", schema="pub", headers={"Custom-Header": "value"}
32+
) as client:
33+
session = client.session
34+
35+
assert session.base_url == "https://example.com"
36+
headers = Headers(
37+
{
38+
"Accept-Profile": "pub",
39+
"Content-Profile": "pub",
40+
"Custom-Header": "value",
41+
}
42+
)
43+
assert session.headers.items() >= headers.items()
2444

2545

2646
class TestAuth:
2747
@pytest.mark.asyncio
28-
def test_auth_token(self, postgrest_client):
48+
def test_auth_token(self, postgrest_client: PostgrestClient):
2949
postgrest_client.auth("s3cr3t")
3050
session = postgrest_client.session
3151

3252
assert session.headers["Authorization"] == "Bearer s3cr3t"
3353

3454
@pytest.mark.asyncio
35-
def test_auth_basic(self, postgrest_client):
55+
def test_auth_basic(self, postgrest_client: PostgrestClient):
3656
postgrest_client.auth(None, username="admin", password="s3cr3t")
3757
session = postgrest_client.session
3858

59+
assert isinstance(session.auth, BasicAuth)
3960
assert session.auth._auth_header == BasicAuth("admin", "s3cr3t")._auth_header
4061

4162

0 commit comments

Comments
 (0)