Skip to content

Commit 4cd5911

Browse files
committed
Add tests and other fixes
1 parent 363b308 commit 4cd5911

File tree

7 files changed

+263
-14
lines changed

7 files changed

+263
-14
lines changed

requirements_dev.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ pre-commit
33
httpx>=0.23
44
pytest~=7.2
55
pytest-asyncio>=0.21
6+
pytest-httpbin==2.1.0
67
freezegun>=1.2.0
78
ruff==0.5.3 # Should match .pre-commit-config.yaml
89
testcontainers # testcontainers<4 may not work with asyncpg

src/dstack/_internal/gateway/deps.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,13 @@
88

99

1010
class BaseGatewayDependencyInjector(ABC):
11+
"""
12+
The gateway uses different implementations of this injector in different
13+
environments: in-serer and on a remote host. An instance with the injector interface
14+
stored in FastAPI's app.state.gateway_dependency_injector configures the gateway to
15+
use a specific set of dependencies, e.g. a specific repo implementation.
16+
"""
17+
1118
@abstractmethod
1219
async def get_repo(self) -> AsyncGenerator[BaseGatewayRepo, None]:
1320
if False:

src/dstack/_internal/gateway/repos/base.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,9 +28,17 @@ class Project(BaseModel):
2828

2929
class BaseGatewayRepo(ABC):
3030
@abstractmethod
31-
async def get_service(self, project_name: str, name: str) -> Optional[Service]:
31+
async def get_service(self, project_name: str, run_name: str) -> Optional[Service]:
32+
pass
33+
34+
@abstractmethod
35+
async def add_service(self, project_name: str, service: Service) -> None:
3236
pass
3337

3438
@abstractmethod
3539
async def get_project(self, name: str) -> Optional[Project]:
3640
pass
41+
42+
@abstractmethod
43+
async def add_project(self, project: Project) -> None:
44+
pass
Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
from typing import Dict, Optional
2+
3+
from dstack._internal.gateway.repos.base import BaseGatewayRepo, Project, Service
4+
5+
6+
class InMemoryGatewayRepo(BaseGatewayRepo):
7+
def __init__(self) -> None:
8+
self.services: Dict[str, Dict[str, Service]] = {}
9+
self.projects: Dict[str, Project] = {}
10+
11+
async def get_service(self, project_name: str, run_name: str) -> Optional[Service]:
12+
return self.services.get(project_name, {}).get(run_name)
13+
14+
async def add_service(self, project_name: str, service: Service) -> None:
15+
self.services.setdefault(project_name, {})[service.run_name] = service
16+
17+
async def get_project(self, name: str) -> Optional[Project]:
18+
return self.projects.get(name)
19+
20+
async def add_project(self, project: Project) -> None:
21+
self.projects[project.name] = project

src/dstack/_internal/gateway/services/service_proxy.py

Lines changed: 15 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
import httpx
66
from starlette.requests import ClientDisconnect
77

8-
from dstack._internal.gateway.repos.base import BaseGatewayRepo
8+
from dstack._internal.gateway.repos.base import BaseGatewayRepo, Replica, Service
99
from dstack._internal.gateway.services.service_connection import service_replica_connection_pool
1010
from dstack._internal.utils.logging import get_logger
1111

@@ -38,15 +38,7 @@ async def proxy(
3838
)
3939

4040
replica = random.choice(service.replicas)
41-
42-
connection = await service_replica_connection_pool.get(replica.id)
43-
if connection is None:
44-
project = await repo.get_project(project_name)
45-
if project is None:
46-
raise RuntimeError(f"Expected to find project {project_name} but could not")
47-
connection = await service_replica_connection_pool.add(project, service, replica)
48-
49-
client = await connection.client()
41+
client = await get_replica_client(project_name, service, replica, repo)
5042

5143
try:
5244
upstream_request = await build_upstream_request(request, path, client, replica.id)
@@ -68,7 +60,7 @@ async def proxy(
6860
replica.id,
6961
e,
7062
)
71-
if isinstance(e, TimeoutError):
63+
if isinstance(e, httpx.TimeoutException):
7264
raise fastapi.HTTPException(fastapi.status.HTTP_504_GATEWAY_TIMEOUT)
7365
raise fastapi.HTTPException(fastapi.status.HTTP_502_BAD_GATEWAY)
7466

@@ -79,6 +71,18 @@ async def proxy(
7971
)
8072

8173

74+
async def get_replica_client(
75+
project_name: str, service: Service, replica: Replica, repo: BaseGatewayRepo
76+
) -> httpx.AsyncClient:
77+
connection = await service_replica_connection_pool.get(replica.id)
78+
if connection is None:
79+
project = await repo.get_project(project_name)
80+
if project is None:
81+
raise RuntimeError(f"Expected to find project {project_name} but could not")
82+
connection = await service_replica_connection_pool.add(project, service, replica)
83+
return await connection.client()
84+
85+
8286
async def stream_response(
8387
response: httpx.Response, replica_id: str
8488
) -> AsyncGenerator[bytes, None]:

src/dstack/_internal/server/services/gateway_in_server/repo.py

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,16 +13,22 @@
1313

1414

1515
class DBGatewayRepo(BaseGatewayRepo):
16+
"""
17+
A gateway repo implementation used for gateway-in-server that retrieves data from
18+
dstack-server's database. Since the database is populated by dstack-server, all or
19+
most writer methods in this implementation are expected to be empty.
20+
"""
21+
1622
def __init__(self, session: AsyncSession) -> None:
1723
self.session = session
1824

19-
async def get_service(self, project_name: str, name: str) -> Optional[Service]:
25+
async def get_service(self, project_name: str, run_name: str) -> Optional[Service]:
2026
res = await self.session.execute(
2127
select(JobModel)
2228
.join(JobModel.project)
2329
.where(
2430
ProjectModel.name == project_name,
25-
JobModel.run_name == name,
31+
JobModel.run_name == run_name,
2632
JobModel.status == JobStatus.RUNNING,
2733
JobModel.job_num == 0,
2834
)
@@ -67,6 +73,9 @@ async def get_service(self, project_name: str, name: str) -> Optional[Service]:
6773
replicas=replicas,
6874
)
6975

76+
async def add_service(self, project_name: str, service: Service) -> None:
77+
pass
78+
7079
async def get_project(self, name: str) -> Optional[Project]:
7180
res = await self.session.execute(select(ProjectModel).where(ProjectModel.name == name))
7281
project = res.scalar_one_or_none()
@@ -76,3 +85,6 @@ async def get_project(self, name: str) -> Optional[Project]:
7685
name=project.name,
7786
ssh_private_key=project.ssh_private_key,
7887
)
88+
89+
async def add_project(self, project: Project) -> None:
90+
pass
Lines changed: 196 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,196 @@
1+
from typing import AsyncGenerator, Generator, Tuple
2+
from unittest.mock import patch
3+
4+
import httpx
5+
import pytest
6+
from fastapi import FastAPI
7+
8+
from dstack._internal.gateway.deps import BaseGatewayDependencyInjector
9+
from dstack._internal.gateway.repos.base import BaseGatewayRepo, Project, Replica, Service
10+
from dstack._internal.gateway.repos.memory import InMemoryGatewayRepo
11+
from dstack._internal.gateway.routers.service_proxy import router
12+
13+
14+
def make_app(repo: BaseGatewayRepo) -> FastAPI:
15+
class DependencyInjector(BaseGatewayDependencyInjector):
16+
async def get_repo(self) -> AsyncGenerator[BaseGatewayRepo, None]:
17+
yield repo
18+
19+
app = FastAPI()
20+
app.state.gateway_dependency_injector = DependencyInjector()
21+
app.include_router(router, prefix="/gateway/services")
22+
return app
23+
24+
25+
def make_client(app: FastAPI) -> httpx.AsyncClient:
26+
return httpx.AsyncClient(transport=httpx.ASGITransport(app=app))
27+
28+
29+
def make_app_client(repo: BaseGatewayRepo) -> Tuple[FastAPI, httpx.AsyncClient]:
30+
app = make_app(repo)
31+
client = make_client(app)
32+
return app, client
33+
34+
35+
def make_project(name: str) -> Project:
36+
return Project(name=name, ssh_private_key="secret")
37+
38+
39+
def make_service(run_name: str) -> Service:
40+
return Service(
41+
id="xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx",
42+
run_name=run_name,
43+
auth=False,
44+
app_port=80,
45+
replicas=[
46+
Replica(
47+
id="xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx",
48+
ssh_destination="ubuntu@server",
49+
ssh_port=22,
50+
ssh_proxy=None,
51+
)
52+
],
53+
)
54+
55+
56+
MOCK_REPLICA_CLIENT_TIMEOUT = 8
57+
58+
59+
@pytest.fixture
60+
def mock_replica_client_httpbin(httpbin) -> Generator[None, None, None]:
61+
with patch(
62+
"dstack._internal.gateway.services.service_proxy.get_replica_client"
63+
) as get_replica_client_mock:
64+
get_replica_client_mock.return_value = httpx.AsyncClient(
65+
base_url=httpbin.url, timeout=MOCK_REPLICA_CLIENT_TIMEOUT
66+
)
67+
yield
68+
69+
70+
@pytest.mark.asyncio
71+
@pytest.mark.parametrize("method", ["get", "post", "put", "patch", "delete"])
72+
async def test_proxy(mock_replica_client_httpbin, method: str) -> None:
73+
methods_without_body = "get", "delete"
74+
repo = InMemoryGatewayRepo()
75+
await repo.add_project(make_project("test-proj"))
76+
await repo.add_service(project_name="test-proj", service=make_service("httpbin"))
77+
_, client = make_app_client(repo)
78+
req_body = "." * 20 * 2**20 if method not in methods_without_body else None
79+
resp = await client.request(
80+
method,
81+
f"http://test-host:8888/gateway/services/test-proj/httpbin/{method}?a=b&c=",
82+
headers={"User-Agent": "test-ua", "Connection": "keep-alive"},
83+
content=req_body,
84+
)
85+
assert resp.status_code == 200
86+
assert resp.headers["server"].startswith("Pytest-HTTPBIN")
87+
resp_body = resp.json()
88+
assert resp_body["url"] == f"http://test-host:8888/{method}?a=b&c="
89+
assert resp_body["args"] == {"a": "b", "c": ""}
90+
assert resp_body["headers"]["Host"] == "test-host:8888"
91+
assert resp_body["headers"]["User-Agent"] == "test-ua"
92+
assert resp_body["headers"]["Connection"] == "keep-alive"
93+
if method not in methods_without_body:
94+
assert resp_body["data"] == req_body
95+
96+
97+
@pytest.mark.asyncio
98+
async def test_proxy_method_head(mock_replica_client_httpbin) -> None:
99+
repo = InMemoryGatewayRepo()
100+
await repo.add_project(make_project("test-proj"))
101+
await repo.add_service(project_name="test-proj", service=make_service("httpbin"))
102+
_, client = make_app_client(repo)
103+
url = "http://test-host/gateway/services/test-proj/httpbin/"
104+
get_resp = await client.get(url)
105+
head_resp = await client.head(url)
106+
assert get_resp.status_code == head_resp.status_code == 200
107+
assert head_resp.headers["Content-Length"] == get_resp.headers["Content-Length"]
108+
assert int(head_resp.headers["Content-Length"]) > 0
109+
assert head_resp.content == b""
110+
111+
112+
@pytest.mark.asyncio
113+
async def test_proxy_method_options(mock_replica_client_httpbin) -> None:
114+
repo = InMemoryGatewayRepo()
115+
await repo.add_project(make_project("test-proj"))
116+
await repo.add_service(project_name="test-proj", service=make_service("httpbin"))
117+
_, client = make_app_client(repo)
118+
resp = await client.options("http://test-host/gateway/services/test-proj/httpbin/get")
119+
assert resp.status_code == 200
120+
assert set(resp.headers["Allow"].split(", ")) == {"HEAD", "GET", "OPTIONS"}
121+
assert resp.content == b""
122+
123+
124+
@pytest.mark.asyncio
125+
@pytest.mark.parametrize("code", [204, 304, 418, 503])
126+
async def test_proxy_status_codes(mock_replica_client_httpbin, code: int) -> None:
127+
repo = InMemoryGatewayRepo()
128+
await repo.add_project(make_project("test-proj"))
129+
await repo.add_service(project_name="test-proj", service=make_service("httpbin"))
130+
_, client = make_app_client(repo)
131+
resp = await client.get(f"http://test-host/gateway/services/test-proj/httpbin/status/{code}")
132+
assert resp.status_code == code
133+
134+
135+
@pytest.mark.asyncio
136+
async def test_proxy_not_leaks_cookies(mock_replica_client_httpbin) -> None:
137+
repo = InMemoryGatewayRepo()
138+
await repo.add_project(make_project("test-proj"))
139+
await repo.add_service(project_name="test-proj", service=make_service("httpbin"))
140+
app = make_app(repo)
141+
client1 = make_client(app)
142+
client2 = make_client(app)
143+
cookies_url = "http://test-host/gateway/services/test-proj/httpbin/cookies"
144+
await client1.get(cookies_url + "/set?a=1")
145+
await client1.get(cookies_url + "/set?b=2")
146+
await client2.get(cookies_url + "/set?a=3")
147+
resp1 = await client1.get(cookies_url)
148+
resp2 = await client2.get(cookies_url)
149+
assert resp1.json()["cookies"] == {"a": "1", "b": "2"}
150+
assert resp2.json()["cookies"] == {"a": "3"}
151+
152+
153+
@pytest.mark.asyncio
154+
async def test_proxy_gateway_timeout(mock_replica_client_httpbin) -> None:
155+
repo = InMemoryGatewayRepo()
156+
await repo.add_project(make_project("test-proj"))
157+
await repo.add_service(project_name="test-proj", service=make_service("httpbin"))
158+
_, client = make_app_client(repo)
159+
assert MOCK_REPLICA_CLIENT_TIMEOUT < 10
160+
resp = await client.get("http://test-host/gateway/services/test-proj/httpbin/delay/10")
161+
assert resp.status_code == 504
162+
assert resp.json()["detail"] == "Gateway Timeout"
163+
164+
165+
@pytest.mark.asyncio
166+
async def test_proxy_run_not_found(mock_replica_client_httpbin) -> None:
167+
repo = InMemoryGatewayRepo()
168+
await repo.add_project(make_project("test-proj"))
169+
await repo.add_service(project_name="test-proj", service=make_service("test-run"))
170+
_, client = make_app_client(repo)
171+
resp = await client.get("http://test-host/gateway/services/test-proj/unknown/")
172+
assert resp.status_code == 404
173+
assert resp.json()["detail"] == "Service test-proj/unknown not found"
174+
175+
176+
@pytest.mark.asyncio
177+
async def test_proxy_project_not_found(mock_replica_client_httpbin) -> None:
178+
_, client = make_app_client(InMemoryGatewayRepo())
179+
resp = await client.get("http://test-host/gateway/services/unknown/test-run/")
180+
assert resp.status_code == 404
181+
assert resp.json()["detail"] == "Service unknown/test-run not found"
182+
183+
184+
@pytest.mark.asyncio
185+
async def test_redirect_to_service_root(mock_replica_client_httpbin) -> None:
186+
repo = InMemoryGatewayRepo()
187+
await repo.add_project(make_project("test-proj"))
188+
await repo.add_service(project_name="test-proj", service=make_service("httpbin"))
189+
_, client = make_app_client(repo)
190+
url = "http://test-host/gateway/services/test-proj/httpbin"
191+
resp = await client.get(url, follow_redirects=False)
192+
assert resp.status_code == 308
193+
assert resp.headers["Location"] == url + "/"
194+
resp = await client.get(url, follow_redirects=True)
195+
assert resp.status_code == 200
196+
assert resp.request.url == url + "/"

0 commit comments

Comments
 (0)