Skip to content

Commit 711b778

Browse files
committed
Refactor _get_implict_client. Add ability to load custom clients from get_client. Refactor env variable to read auth headers as dict.
1 parent 896ce80 commit 711b778

File tree

3 files changed

+70
-38
lines changed

3 files changed

+70
-38
lines changed

.github/workflows/test.yml

+4-4
Original file line numberDiff line numberDiff line change
@@ -32,10 +32,10 @@ jobs:
3232
JUDGE0_ATD_API_KEY: ${{ secrets.JUDGE0_ATD_API_KEY }}
3333
JUDGE0_RAPID_API_KEY: ${{ secrets.JUDGE0_RAPID_API_KEY }}
3434
JUDGE0_SULU_API_KEY: ${{ secrets.JUDGE0_SULU_API_KEY }}
35-
JUDGE0_TEST_API_KEY: ${{ secrets.JUDGE0_TEST_API_KEY }}
36-
JUDGE0_TEST_API_KEY_HEADER: ${{ secrets.JUDGE0_TEST_API_KEY_HEADER }}
37-
JUDGE0_TEST_CE_ENDPOINT: ${{ secrets.JUDGE0_TEST_CE_ENDPOINT }}
38-
JUDGE0_TEST_EXTRA_CE_ENDPOINT: ${{ secrets.JUDGE0_TEST_EXTRA_CE_ENDPOINT }}
35+
JUDGE0_CE_AUTH_HEADERS: ${{ secrets.JUDGE0_CE_AUTH_HEADERS }}
36+
JUDGE0_EXTRA_CE_AUTH_HEADERS: ${{ secrets.JUDGE0_EXTRA_CE_AUTH_HEADERS }}
37+
JUDGE0_CE_ENDPOINT: ${{ secrets.JUDGE0_CE_ENDPOINT }}
38+
JUDGE0_EXTRA_CE_ENDPOINT: ${{ secrets.JUDGE0_EXTRA_CE_ENDPOINT }}
3939
run: |
4040
source venv/bin/activate
4141
pytest -vv tests/

src/judge0/__init__.py

+58-16
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
import os
22

3+
from typing import Union
4+
35
from .api import (
46
async_execute,
57
async_run,
@@ -73,36 +75,25 @@ def _get_implicit_client(flavor: Flavor) -> Client:
7375
if flavor == Flavor.EXTRA_CE and JUDGE0_IMPLICIT_EXTRA_CE_CLIENT is not None:
7476
return JUDGE0_IMPLICIT_EXTRA_CE_CLIENT
7577

76-
from .clients import CE, EXTRA_CE
77-
7878
try:
7979
from dotenv import load_dotenv
8080

8181
load_dotenv()
8282
except: # noqa: E722
8383
pass
8484

85-
if flavor == Flavor.CE:
86-
client_classes = CE
87-
else:
88-
client_classes = EXTRA_CE
85+
# Let's check if we can find a self-hosted client.
86+
client = _get_custom_client(flavor)
8987

9088
# Try to find one of the predefined keys JUDGE0_{SULU,RAPID,ATD}_API_KEY
9189
# in environment variables.
92-
client = None
93-
for client_class in client_classes:
94-
api_key = os.getenv(client_class.API_KEY_ENV)
95-
if api_key is not None:
96-
client = client_class(api_key)
97-
break
90+
if client is None:
91+
client = _get_predefined_client(flavor)
9892

9993
# If we didn't find any of the possible predefined keys, initialize
10094
# the preview Sulu client based on the flavor.
10195
if client is None:
102-
if flavor == Flavor.CE:
103-
client = SuluJudge0CE(retry_strategy=RegularPeriodRetry(0.5))
104-
else:
105-
client = SuluJudge0ExtraCE(retry_strategy=RegularPeriodRetry(0.5))
96+
client = _get_preview_client(flavor)
10697

10798
if flavor == Flavor.CE:
10899
JUDGE0_IMPLICIT_CE_CLIENT = client
@@ -112,6 +103,57 @@ def _get_implicit_client(flavor: Flavor) -> Client:
112103
return client
113104

114105

106+
def _get_preview_client(flavor: Flavor) -> Union[SuluJudge0CE, SuluJudge0ExtraCE]:
107+
if flavor == Flavor.CE:
108+
return SuluJudge0CE(retry_strategy=RegularPeriodRetry(0.5))
109+
else:
110+
return SuluJudge0ExtraCE(retry_strategy=RegularPeriodRetry(0.5))
111+
112+
113+
def _get_custom_client(flavor: Flavor) -> Union[Client, None]:
114+
ce_endpoint = os.getenv("JUDGE0_CE_ENDPOINT")
115+
ce_auth_header = os.getenv("JUDGE0_CE_AUTH_HEADERS")
116+
extra_ce_endpoint = os.getenv("JUDGE0_EXTRA_CE_ENDPOINT")
117+
extra_ce_auth_header = os.getenv("JUDGE0_EXTRA_CE_AUTH_HEADERS")
118+
119+
if flavor == Flavor.CE and ce_endpoint is not None and ce_auth_header is not None:
120+
return Client(
121+
endpoint=ce_endpoint,
122+
auth_headers=eval(ce_auth_header),
123+
)
124+
125+
if (
126+
flavor == Flavor.EXTRA_CE
127+
and extra_ce_endpoint is not None
128+
and extra_ce_auth_header is not None
129+
):
130+
return Client(
131+
endpoint=extra_ce_endpoint,
132+
auth_headers=eval(extra_ce_auth_header),
133+
)
134+
135+
return None
136+
137+
138+
def _get_predefined_client(flavor: Flavor) -> Union[Client, None]:
139+
from .clients import CE, EXTRA_CE
140+
141+
if flavor == Flavor.CE:
142+
client_classes = CE
143+
else:
144+
client_classes = EXTRA_CE
145+
146+
for client_class in client_classes:
147+
api_key = os.getenv(client_class.API_KEY_ENV)
148+
if api_key is not None:
149+
client = client_class(api_key)
150+
break
151+
else:
152+
client = None
153+
154+
return client
155+
156+
115157
CE = Flavor.CE
116158
EXTRA_CE = Flavor.EXTRA_CE
117159

tests/conftest.py

+8-18
Original file line numberDiff line numberDiff line change
@@ -10,34 +10,24 @@
1010

1111
@pytest.fixture(scope="session")
1212
def judge0_ce_client():
13-
api_key = os.getenv("JUDGE0_TEST_API_KEY")
14-
api_key_header = os.getenv("JUDGE0_TEST_API_KEY_HEADER")
15-
endpoint = os.getenv("JUDGE0_TEST_CE_ENDPOINT")
13+
endpoint = os.getenv("JUDGE0_CE_ENDPOINT")
14+
auth_headers = os.getenv("JUDGE0_CE_AUTH_HEADERS")
1615

17-
if api_key is None or api_key_header is None or endpoint is None:
16+
if endpoint is None or auth_headers is None:
1817
return None
1918
else:
20-
client = clients.Client(
21-
endpoint=endpoint,
22-
auth_headers={api_key_header: api_key},
23-
)
24-
return client
19+
return clients.Client(endpoint=endpoint, auth_headers=eval(auth_headers))
2520

2621

2722
@pytest.fixture(scope="session")
2823
def judge0_extra_ce_client():
29-
api_key = os.getenv("JUDGE0_TEST_API_KEY")
30-
api_key_header = os.getenv("JUDGE0_TEST_API_KEY_HEADER")
31-
endpoint = os.getenv("JUDGE0_TEST_EXTRA_CE_ENDPOINT")
24+
endpoint = os.getenv("JUDGE0_EXTRA_CE_ENDPOINT")
25+
auth_headers = os.getenv("JUDGE0_EXTRA_CE_AUTH_HEADERS")
3226

33-
if api_key is None or api_key_header is None or endpoint is None:
27+
if endpoint is None or auth_headers is None:
3428
return None
3529
else:
36-
client = clients.Client(
37-
endpoint=endpoint,
38-
auth_headers={api_key_header: api_key},
39-
)
40-
return client
30+
return clients.Client(endpoint=endpoint, auth_headers=eval(auth_headers))
4131

4232

4333
@pytest.fixture(scope="session")

0 commit comments

Comments
 (0)