Skip to content

Commit fc5286d

Browse files
authored
Allow specifying gateway in service configurations (#1972)
1 parent 8eb9b15 commit fc5286d

File tree

4 files changed

+214
-13
lines changed

4 files changed

+214
-13
lines changed

src/dstack/_internal/core/models/configurations.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -209,6 +209,15 @@ class ServiceConfigurationParams(CoreModel):
209209
Union[ValidPort, constr(regex=r"^[0-9]+:[0-9]+$"), PortMapping],
210210
Field(description="The port, that application listens on or the mapping"),
211211
]
212+
gateway: Annotated[
213+
Optional[Union[bool, str]],
214+
Field(
215+
description=(
216+
"The name of the gateway. Specify boolean `false` to run without a gateway."
217+
" Omit to run with the default gateway"
218+
),
219+
),
220+
] = None
212221
model: Annotated[
213222
Optional[Union[AnyModel, str]],
214223
Field(
@@ -267,6 +276,16 @@ def convert_replicas(cls, v: Any) -> Range[int]:
267276
)
268277
return v
269278

279+
@validator("gateway")
280+
def validate_gateway(
    cls, v: Optional[Union[bool, str]]
) -> Optional[Union[Literal[False], str]]:
    """Reject boolean ``true`` as a value of the ``gateway`` property.

    Args:
        v: The value of the ``gateway`` property: a string, a boolean, or ``None``.

    Returns:
        ``v`` unchanged when it is a string, ``False``, or ``None``.

    Raises:
        ValueError: If ``v`` is boolean ``True``.
    """
    # Identity check against the singleton instead of `== True` (PEP 8).
    if v is True:
        raise ValueError(
            "The `gateway` property must be a string or boolean `false`, not boolean `true`"
        )
    return v
288+
270289
@root_validator()
271290
def validate_scaling(cls, values):
272291
scaling = values.get("scaling")

src/dstack/_internal/server/services/gateways/__init__.py

Lines changed: 21 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
SSHError,
2929
)
3030
from dstack._internal.core.models.backends.base import BackendType
31+
from dstack._internal.core.models.common import is_core_model_instance
3132
from dstack._internal.core.models.configurations import SERVICE_HTTPS_DEFAULT, ServiceConfiguration
3233
from dstack._internal.core.models.gateways import (
3334
Gateway,
@@ -354,27 +355,38 @@ async def generate_gateway_name(session: AsyncSession, project: ProjectModel) ->
354355
return name
355356

356357

357-
async def register_service(session: AsyncSession, run_model: RunModel):
358-
gateway = run_model.project.default_gateway
358+
async def register_service(session: AsyncSession, run_model: RunModel, run_spec: RunSpec):
    """Register the service described by ``run_spec`` and store its spec on ``run_model``.

    The gateway is chosen from the ``gateway`` property of the service configuration:

    * a string -- the project gateway with that name is looked up;
    * boolean ``False`` -- no gateway is used;
    * anything else (e.g. omitted) -- the project's default gateway is used.

    When a gateway was selected, the service is registered in it and the gateway
    is assigned to ``run_model.gateway``. Otherwise the service is registered in
    the server, unless ``settings.FORBID_SERVICES_WITHOUT_GATEWAY`` is set.

    Raises:
        ResourceNotExistsError: If the gateway named in the configuration is not
            found, or if no gateway is selected while
            ``settings.FORBID_SERVICES_WITHOUT_GATEWAY`` is set.
    """
    assert is_core_model_instance(run_spec.configuration, ServiceConfiguration)
    gateway_setting = run_spec.configuration.gateway

    if isinstance(gateway_setting, str):
        gateway = await get_project_gateway_model_by_name(
            session=session, project=run_model.project, name=gateway_setting
        )
        if gateway is None:
            raise ResourceNotExistsError(f"Gateway {gateway_setting} does not exist")
    elif gateway_setting is False:
        # Explicit opt-out: do not fall back to the default gateway.
        gateway = None
    else:
        gateway = run_model.project.default_gateway

    if gateway is not None:
        service_spec = await _register_service_in_gateway(session, run_model, run_spec, gateway)
        run_model.gateway = gateway
    elif not settings.FORBID_SERVICES_WITHOUT_GATEWAY:
        service_spec = _register_service_in_server(run_model, run_spec)
    else:
        raise ResourceNotExistsError(
            "This dstack-server installation forbids services without a gateway."
            " Please configure a gateway."
        )
    run_model.service_spec = service_spec.json()
371385

372386

373387
async def _register_service_in_gateway(
374-
session: AsyncSession, run_model: RunModel, gateway: GatewayModel
388+
session: AsyncSession, run_model: RunModel, run_spec: RunSpec, gateway: GatewayModel
375389
) -> ServiceSpec:
376-
run_spec: RunSpec = RunSpec.__response__.parse_raw(run_model.run_spec)
377-
378390
if gateway.gateway_compute is None:
379391
raise ServerClientError("Gateway has no instance associated with it")
380392

@@ -427,8 +439,7 @@ async def _register_service_in_gateway(
427439
return service_spec
428440

429441

430-
def _register_service_in_server(run_model: RunModel) -> ServiceSpec:
431-
run_spec: RunSpec = RunSpec.__response__.parse_raw(run_model.run_spec)
442+
def _register_service_in_server(run_model: RunModel, run_spec: RunSpec) -> ServiceSpec:
432443
if run_spec.configuration.https != SERVICE_HTTPS_DEFAULT:
433444
# Note: if the user sets `https: <default-value>`, it will be ignored silently
434445
# TODO: in 0.19, make `https` Optional to be able to tell if it was set or omitted

src/dstack/_internal/server/services/runs.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -431,7 +431,7 @@ async def submit_run(
431431
replicas = 1
432432
if run_spec.configuration.type == "service":
433433
replicas = run_spec.configuration.replicas.min
434-
await gateways.register_service(session, run_model)
434+
await gateways.register_service(session, run_model, run_spec)
435435

436436
for replica_num in range(replicas):
437437
jobs = await get_jobs_from_run_spec(run_spec, replica_num=replica_num)

src/tests/_internal/server/routers/test_runs.py

Lines changed: 173 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
import copy
22
import json
33
from datetime import datetime, timezone
4-
from typing import Dict, List, Optional
5-
from unittest.mock import Mock, patch
4+
from typing import Dict, Generator, List, Optional, Tuple, Union
5+
from unittest.mock import AsyncMock, Mock, patch
66
from uuid import UUID
77

88
import pytest
@@ -14,6 +14,7 @@
1414
from dstack._internal.core.models.backends.base import BackendType
1515
from dstack._internal.core.models.common import ApplyAction
1616
from dstack._internal.core.models.configurations import ServiceConfiguration
17+
from dstack._internal.core.models.gateways import GatewayStatus
1718
from dstack._internal.core.models.instances import (
1819
InstanceAvailability,
1920
InstanceOfferWithAvailability,
@@ -43,6 +44,9 @@
4344
from dstack._internal.server.services.projects import add_project_member
4445
from dstack._internal.server.services.runs import run_model_to_run
4546
from dstack._internal.server.testing.common import (
47+
create_backend,
48+
create_gateway,
49+
create_gateway_compute,
4650
create_job,
4751
create_project,
4852
create_repo,
@@ -358,6 +362,32 @@ def get_dev_env_run_dict(
358362
}
359363

360364

365+
def get_service_run_spec(
    repo_id: str,
    run_name: Optional[str] = None,
    gateway: Optional[Union[bool, str]] = None,
) -> dict:
    """Build a run spec dict for a service, for use as a request payload in tests.

    Args:
        repo_id: Value placed under the ``repo_id`` key.
        run_name: Value placed under the ``run_name`` key.
        gateway: Value placed under ``configuration.gateway``.

    Returns:
        The run spec as a plain dictionary.
    """
    service_configuration = {
        "type": "service",
        "commands": ["python -m http.server"],
        "port": 8000,
        "gateway": gateway,
        "model": "test-model",
    }
    run_spec = {
        "configuration": service_configuration,
        "configuration_path": "dstack.yaml",
        "profile": {"name": "string"},
        "repo_code_hash": None,
        "repo_data": {"repo_dir": "/repo", "repo_type": "local"},
        "repo_id": repo_id,
        "run_name": run_name,
        "ssh_key_pub": "ssh_key",
        "working_dir": ".",
    }
    return run_spec
389+
390+
361391
class TestListRuns:
362392
@pytest.mark.asyncio
363393
@pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True)
@@ -1481,3 +1511,144 @@ async def test_backend_does_not_support_create_instance(
14811511
]
14821512
}
14831513
assert result == expected
1514+
1515+
1516+
@pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True)
class TestSubmitService:
    """Tests for submitting service runs with different `gateway` settings."""

    @pytest.fixture(autouse=True)
    def mock_gateway_connections(self) -> Generator[None, None, None]:
        """Patch the gateway connection pool lookup for every test in this class."""
        with patch(
            "dstack._internal.server.services.gateways.gateway_connections_pool.get_or_add"
        ) as get_conn_mock:
            # Calling `client()` on the returned connection yields an AsyncMock.
            get_conn_mock.return_value.client = Mock()
            get_conn_mock.return_value.client.return_value = AsyncMock()
            yield

    @pytest.mark.asyncio
    @pytest.mark.parametrize(
        (
            "existing_gateways",
            "specified_gateway_in_run_conf",
            "expected_service_url",
            "expected_model_url",
        ),
        [
            pytest.param(
                [("default-gateway", True), ("non-default-gateway", False)],
                None,
                "https://test-service.default-gateway.example",
                "https://gateway.default-gateway.example",
                id="submits-to-default-gateway",
            ),
            pytest.param(
                [("default-gateway", True), ("non-default-gateway", False)],
                "non-default-gateway",
                "https://test-service.non-default-gateway.example",
                "https://gateway.non-default-gateway.example",
                id="submits-to-specified-gateway",
            ),
            pytest.param(
                [("non-default-gateway", False)],
                None,
                "/proxy/services/test-project/test-service/",
                "/proxy/models/test-project/",
                id="submits-in-server-when-no-default-gateway",
            ),
            pytest.param(
                [("default-gateway", True)],
                False,
                "/proxy/services/test-project/test-service/",
                "/proxy/models/test-project/",
                id="submits-in-server-when-specified",
            ),
        ],
    )
    async def test_submit_to_correct_proxy(
        self,
        test_db,
        session: AsyncSession,
        client: AsyncClient,
        existing_gateways: List[Tuple[str, bool]],
        # The parametrized cases pass None, False, and a gateway name, so `str` was too narrow.
        specified_gateway_in_run_conf: Optional[Union[bool, str]],
        expected_service_url: str,
        expected_model_url: str,
    ) -> None:
        """Check the service and model URLs returned for each gateway setup.

        `existing_gateways` is a list of `(gateway name, is default)` pairs
        created in the project before the run is submitted.
        """
        user = await create_user(session=session, global_role=GlobalRole.USER)
        project = await create_project(session=session, owner=user, name="test-project")
        await add_project_member(
            session=session, project=project, user=user, project_role=ProjectRole.USER
        )
        repo = await create_repo(session=session, project_id=project.id)
        backend = await create_backend(session=session, project_id=project.id)
        for gateway_name, is_default in existing_gateways:
            gateway_compute = await create_gateway_compute(
                session=session,
                backend_id=backend.id,
            )
            gateway = await create_gateway(
                session=session,
                project_id=project.id,
                backend_id=backend.id,
                gateway_compute_id=gateway_compute.id,
                status=GatewayStatus.RUNNING,
                name=gateway_name,
                wildcard_domain=f"{gateway_name}.example",
            )
            if is_default:
                project.default_gateway_id = gateway.id
                await session.commit()
        run_spec = get_service_run_spec(
            repo_id=repo.name,
            run_name="test-service",
            gateway=specified_gateway_in_run_conf,
        )
        response = await client.post(
            f"/api/project/{project.name}/runs/submit",
            headers=get_auth_headers(user.token),
            json={"run_spec": run_spec},
        )
        assert response.status_code == 200
        assert response.json()["service"]["url"] == expected_service_url
        assert response.json()["service"]["model"]["base_url"] == expected_model_url

    @pytest.mark.asyncio
    async def test_return_error_if_specified_gateway_not_exists(
        self, test_db, session: AsyncSession, client: AsyncClient
    ) -> None:
        """Submitting with a gateway name that was never created returns HTTP 400."""
        user = await create_user(session=session, global_role=GlobalRole.USER)
        project = await create_project(session=session, owner=user)
        await add_project_member(
            session=session, project=project, user=user, project_role=ProjectRole.USER
        )
        repo = await create_repo(session=session, project_id=project.id)
        run_spec = get_service_run_spec(repo_id=repo.name, gateway="nonexistent")
        response = await client.post(
            f"/api/project/{project.name}/runs/submit",
            headers=get_auth_headers(user.token),
            json={"run_spec": run_spec},
        )
        assert response.status_code == 400
        assert response.json() == {
            "detail": [
                {"msg": "Gateway nonexistent does not exist", "code": "resource_not_exists"}
            ]
        }

    @pytest.mark.asyncio
    async def test_return_error_if_specified_gateway_is_true(
        self, test_db, session: AsyncSession, client: AsyncClient
    ) -> None:
        """Submitting with `gateway: true` returns HTTP 422 with the validation message."""
        user = await create_user(session=session, global_role=GlobalRole.USER)
        project = await create_project(session=session, owner=user)
        await add_project_member(
            session=session, project=project, user=user, project_role=ProjectRole.USER
        )
        repo = await create_repo(session=session, project_id=project.id)
        run_spec = get_service_run_spec(repo_id=repo.name, gateway=True)
        response = await client.post(
            f"/api/project/{project.name}/runs/submit",
            headers=get_auth_headers(user.token),
            json={"run_spec": run_spec},
        )
        assert response.status_code == 422
        assert "must be a string or boolean `false`, not boolean `true`" in response.text

0 commit comments

Comments
 (0)