Skip to content

Commit 3684ef1

Browse files
authored
Gateway-in-server early prototype (#1718)
This commit implements most of the reverse proxying logic for gateway-in-server. It also includes a dependency injection mechanism that will allow the new gateway app to work with different repo (storage) implementations in-server and remotely. For this prototype, gateway-in-server duplicates remote gateways, i.e. all services are available both on a remote gateway and on gateway-in-server. This will be changed later. Behind the GATEWAY_IN_SERVER feature flag.
1 parent 4392ff6 commit 3684ef1

File tree

16 files changed

+732
-2
lines changed

16 files changed

+732
-2
lines changed

requirements_dev.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ pre-commit
33
httpx>=0.23
44
pytest~=7.2
55
pytest-asyncio>=0.21
6+
pytest-httpbin==2.1.0
67
freezegun>=1.2.0
78
ruff==0.5.3 # Should match .pre-commit-config.yaml
89
testcontainers # testcontainers<4 may not work with asyncpg

src/dstack/_internal/gateway/deps.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
from abc import ABC, abstractmethod
2+
from typing import AsyncGenerator
3+
4+
from fastapi import Depends, Request
5+
from typing_extensions import Annotated
6+
7+
from dstack._internal.gateway.repos.base import BaseGatewayRepo
8+
9+
10+
class BaseGatewayDependencyInjector(ABC):
11+
"""
12+
The gateway uses different implementations of this injector in different
13+
environments: in-serer and on a remote host. An instance with the injector interface
14+
stored in FastAPI's app.state.gateway_dependency_injector configures the gateway to
15+
use a specific set of dependencies, e.g. a specific repo implementation.
16+
"""
17+
18+
@abstractmethod
19+
async def get_repo(self) -> AsyncGenerator[BaseGatewayRepo, None]:
20+
if False:
21+
yield # show type checkers this is a generator
22+
23+
24+
async def get_injector(request: Request) -> BaseGatewayDependencyInjector:
25+
injector = request.app.state.gateway_dependency_injector
26+
if not isinstance(injector, BaseGatewayDependencyInjector):
27+
raise RuntimeError(f"Wrong BaseGatewayDependencyInjector type {type(injector)}")
28+
return injector
29+
30+
31+
async def get_gateway_repo(
32+
injector: Annotated[BaseGatewayDependencyInjector, Depends(get_injector)],
33+
) -> AsyncGenerator[BaseGatewayRepo, None]:
34+
async for repo in injector.get_repo():
35+
yield repo

src/dstack/_internal/gateway/repos/__init__.py

Whitespace-only changes.
Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
from abc import ABC, abstractmethod
2+
from typing import List, Optional
3+
4+
from pydantic import BaseModel
5+
6+
from dstack._internal.core.models.instances import SSHConnectionParams
7+
8+
9+
class Replica(BaseModel):
10+
id: str
11+
ssh_destination: str
12+
ssh_port: int
13+
ssh_proxy: Optional[SSHConnectionParams]
14+
15+
16+
class Service(BaseModel):
17+
id: str
18+
run_name: str
19+
auth: bool
20+
app_port: int
21+
replicas: List[Replica]
22+
23+
24+
class Project(BaseModel):
25+
name: str
26+
ssh_private_key: str
27+
28+
29+
class BaseGatewayRepo(ABC):
30+
@abstractmethod
31+
async def get_service(self, project_name: str, run_name: str) -> Optional[Service]:
32+
pass
33+
34+
@abstractmethod
35+
async def add_service(self, project_name: str, service: Service) -> None:
36+
pass
37+
38+
@abstractmethod
39+
async def get_project(self, name: str) -> Optional[Project]:
40+
pass
41+
42+
@abstractmethod
43+
async def add_project(self, project: Project) -> None:
44+
pass
Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
from typing import Dict, Optional
2+
3+
from dstack._internal.gateway.repos.base import BaseGatewayRepo, Project, Service
4+
5+
6+
class InMemoryGatewayRepo(BaseGatewayRepo):
7+
def __init__(self) -> None:
8+
self.services: Dict[str, Dict[str, Service]] = {}
9+
self.projects: Dict[str, Project] = {}
10+
11+
async def get_service(self, project_name: str, run_name: str) -> Optional[Service]:
12+
return self.services.get(project_name, {}).get(run_name)
13+
14+
async def add_service(self, project_name: str, service: Service) -> None:
15+
self.services.setdefault(project_name, {})[service.run_name] = service
16+
17+
async def get_project(self, name: str) -> Optional[Project]:
18+
return self.projects.get(name)
19+
20+
async def add_project(self, project: Project) -> None:
21+
self.projects[project.name] = project

src/dstack/_internal/gateway/routers/__init__.py

Whitespace-only changes.
Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
from fastapi import APIRouter, Depends, Request, status
2+
from fastapi.datastructures import URL
3+
from fastapi.responses import RedirectResponse, Response
4+
from typing_extensions import Annotated
5+
6+
from dstack._internal.gateway.deps import get_gateway_repo
7+
from dstack._internal.gateway.repos.base import BaseGatewayRepo
8+
from dstack._internal.gateway.services import service_proxy
9+
10+
REDIRECTED_HTTP_METHODS = ["GET", "POST", "PUT", "DELETE", "PATCH", "HEAD"]
11+
PROXIED_HTTP_METHODS = REDIRECTED_HTTP_METHODS + ["OPTIONS"]
12+
13+
14+
router = APIRouter()
15+
16+
17+
@router.api_route("/{project_name}/{run_name}", methods=REDIRECTED_HTTP_METHODS)
18+
async def redirect_to_service_root(request: Request) -> Response:
19+
url = URL(str(request.url))
20+
url = url.replace(path=url.path + "/")
21+
return RedirectResponse(url, status.HTTP_308_PERMANENT_REDIRECT)
22+
23+
24+
@router.api_route("/{project_name}/{run_name}/{path:path}", methods=PROXIED_HTTP_METHODS)
25+
async def service_reverse_proxy(
26+
project_name: str,
27+
run_name: str,
28+
path: str,
29+
request: Request,
30+
repo: Annotated[BaseGatewayRepo, Depends(get_gateway_repo)],
31+
) -> Response:
32+
return await service_proxy.proxy(project_name, run_name, path, request, repo)
33+
34+
35+
# TODO(#1595): support websockets

src/dstack/_internal/gateway/services/__init__.py

Whitespace-only changes.
Lines changed: 105 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,105 @@
1+
import asyncio
2+
from pathlib import Path
3+
from tempfile import TemporaryDirectory
4+
from typing import Dict, Optional
5+
6+
from httpx import AsyncClient, AsyncHTTPTransport
7+
8+
from dstack._internal.core.services.ssh.tunnel import (
9+
SSH_DEFAULT_OPTIONS,
10+
IPSocket,
11+
SocketPair,
12+
SSHTunnel,
13+
UnixSocket,
14+
)
15+
from dstack._internal.gateway.repos.base import Project, Replica, Service
16+
from dstack._internal.utils.logging import get_logger
17+
from dstack._internal.utils.path import FileContent
18+
19+
logger = get_logger(__name__)
20+
OPEN_TUNNEL_TIMEOUT = 10
21+
22+
23+
class ServiceReplicaConnection:
24+
def __init__(self, project: Project, service: Service, replica: Replica) -> None:
25+
self._temp_dir = TemporaryDirectory()
26+
app_socket_path = (Path(self._temp_dir.name) / "replica.sock").absolute()
27+
self._tunnel = SSHTunnel(
28+
destination=replica.ssh_destination,
29+
port=replica.ssh_port,
30+
ssh_proxy=replica.ssh_proxy,
31+
identity=FileContent(project.ssh_private_key),
32+
forwarded_sockets=[
33+
SocketPair(
34+
remote=IPSocket("localhost", service.app_port),
35+
local=UnixSocket(app_socket_path),
36+
),
37+
],
38+
options={
39+
**SSH_DEFAULT_OPTIONS,
40+
"ConnectTimeout": str(OPEN_TUNNEL_TIMEOUT),
41+
},
42+
)
43+
self._client = AsyncClient(
44+
transport=AsyncHTTPTransport(uds=str(app_socket_path)),
45+
# The hostname in base_url is normally a placeholder, it will be overwritten
46+
# by proxied requests' Host header unless they don't have it (HTTP/1.0)
47+
base_url="http://service/",
48+
)
49+
self._is_open = asyncio.locks.Event()
50+
51+
async def open(self) -> None:
52+
await self._tunnel.aopen()
53+
self._is_open.set()
54+
55+
async def close(self) -> None:
56+
self._is_open.clear()
57+
await self._client.aclose()
58+
await self._tunnel.aclose()
59+
60+
async def client(self) -> AsyncClient:
61+
await asyncio.wait_for(self._is_open.wait(), timeout=OPEN_TUNNEL_TIMEOUT)
62+
return self._client
63+
64+
65+
class ServiceReplicaConnectionPool:
66+
def __init__(self) -> None:
67+
# TODO(#1595): remove connections to stopped replicas
68+
self.connections: Dict[str, ServiceReplicaConnection] = {}
69+
70+
async def get(self, replica_id: str) -> Optional[ServiceReplicaConnection]:
71+
return self.connections.get(replica_id)
72+
73+
async def add(
74+
self, project: Project, service: Service, replica: Replica
75+
) -> ServiceReplicaConnection:
76+
connection = self.connections.get(replica.id)
77+
if connection is not None:
78+
return connection
79+
connection = ServiceReplicaConnection(project, service, replica)
80+
self.connections[replica.id] = connection
81+
try:
82+
await connection.open()
83+
except BaseException:
84+
self.connections.pop(replica.id, None)
85+
raise
86+
return connection
87+
88+
async def remove(self, replica_id: str) -> None:
89+
connection = self.connections.pop(replica_id, None)
90+
if connection is not None:
91+
await connection.close()
92+
93+
async def remove_all(self) -> None:
94+
replica_ids = list(self.connections)
95+
results = await asyncio.gather(
96+
*(self.remove(replica_id) for replica_id in replica_ids), return_exceptions=True
97+
)
98+
for i, exc in enumerate(results):
99+
if isinstance(exc, Exception):
100+
logger.error(
101+
"Error removing connection to service replica %s: %s", replica_ids[i], exc
102+
)
103+
104+
105+
service_replica_connection_pool: ServiceReplicaConnectionPool = ServiceReplicaConnectionPool()

0 commit comments

Comments
 (0)