|
1 | 1 | import copy
|
2 | 2 | import json
|
3 | 3 | from datetime import datetime, timezone
|
4 |
| -from typing import Dict, List, Optional |
5 |
| -from unittest.mock import Mock, patch |
| 4 | +from typing import Dict, Generator, List, Optional, Tuple, Union |
| 5 | +from unittest.mock import AsyncMock, Mock, patch |
6 | 6 | from uuid import UUID
|
7 | 7 |
|
8 | 8 | import pytest
|
|
14 | 14 | from dstack._internal.core.models.backends.base import BackendType
|
15 | 15 | from dstack._internal.core.models.common import ApplyAction
|
16 | 16 | from dstack._internal.core.models.configurations import ServiceConfiguration
|
| 17 | +from dstack._internal.core.models.gateways import GatewayStatus |
17 | 18 | from dstack._internal.core.models.instances import (
|
18 | 19 | InstanceAvailability,
|
19 | 20 | InstanceOfferWithAvailability,
|
|
43 | 44 | from dstack._internal.server.services.projects import add_project_member
|
44 | 45 | from dstack._internal.server.services.runs import run_model_to_run
|
45 | 46 | from dstack._internal.server.testing.common import (
|
| 47 | + create_backend, |
| 48 | + create_gateway, |
| 49 | + create_gateway_compute, |
46 | 50 | create_job,
|
47 | 51 | create_project,
|
48 | 52 | create_repo,
|
@@ -358,6 +362,32 @@ def get_dev_env_run_dict(
|
358 | 362 | }
|
359 | 363 |
|
360 | 364 |
|
| 365 | +def get_service_run_spec( |
| 366 | + repo_id: str, |
| 367 | + run_name: Optional[str] = None, |
| 368 | + gateway: Optional[Union[bool, str]] = None, |
| 369 | +) -> dict: |
| 370 | + return { |
| 371 | + "configuration": { |
| 372 | + "type": "service", |
| 373 | + "commands": ["python -m http.server"], |
| 374 | + "port": 8000, |
| 375 | + "gateway": gateway, |
| 376 | + "model": "test-model", |
| 377 | + }, |
| 378 | + "configuration_path": "dstack.yaml", |
| 379 | + "profile": { |
| 380 | + "name": "string", |
| 381 | + }, |
| 382 | + "repo_code_hash": None, |
| 383 | + "repo_data": {"repo_dir": "/repo", "repo_type": "local"}, |
| 384 | + "repo_id": repo_id, |
| 385 | + "run_name": run_name, |
| 386 | + "ssh_key_pub": "ssh_key", |
| 387 | + "working_dir": ".", |
| 388 | + } |
| 389 | + |
| 390 | + |
361 | 391 | class TestListRuns:
|
362 | 392 | @pytest.mark.asyncio
|
363 | 393 | @pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True)
|
@@ -1481,3 +1511,144 @@ async def test_backend_does_not_support_create_instance(
|
1481 | 1511 | ]
|
1482 | 1512 | }
|
1483 | 1513 | assert result == expected
|
| 1514 | + |
| 1515 | + |
| 1516 | +@pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True) |
| 1517 | +class TestSubmitService: |
| 1518 | + @pytest.fixture(autouse=True) |
| 1519 | + def mock_gateway_connections(self) -> Generator[None, None, None]: |
| 1520 | + with patch( |
| 1521 | + "dstack._internal.server.services.gateways.gateway_connections_pool.get_or_add" |
| 1522 | + ) as get_conn_mock: |
| 1523 | + get_conn_mock.return_value.client = Mock() |
| 1524 | + get_conn_mock.return_value.client.return_value = AsyncMock() |
| 1525 | + yield |
| 1526 | + |
| 1527 | + @pytest.mark.asyncio |
| 1528 | + @pytest.mark.parametrize( |
| 1529 | + ( |
| 1530 | + "existing_gateways", |
| 1531 | + "specified_gateway_in_run_conf", |
| 1532 | + "expected_service_url", |
| 1533 | + "expected_model_url", |
| 1534 | + ), |
| 1535 | + [ |
| 1536 | + pytest.param( |
| 1537 | + [("default-gateway", True), ("non-default-gateway", False)], |
| 1538 | + None, |
| 1539 | + "https://test-service.default-gateway.example", |
| 1540 | + "https://gateway.default-gateway.example", |
| 1541 | + id="submits-to-default-gateway", |
| 1542 | + ), |
| 1543 | + pytest.param( |
| 1544 | + [("default-gateway", True), ("non-default-gateway", False)], |
| 1545 | + "non-default-gateway", |
| 1546 | + "https://test-service.non-default-gateway.example", |
| 1547 | + "https://gateway.non-default-gateway.example", |
| 1548 | + id="submits-to-specified-gateway", |
| 1549 | + ), |
| 1550 | + pytest.param( |
| 1551 | + [("non-default-gateway", False)], |
| 1552 | + None, |
| 1553 | + "/proxy/services/test-project/test-service/", |
| 1554 | + "/proxy/models/test-project/", |
| 1555 | + id="submits-in-server-when-no-default-gateway", |
| 1556 | + ), |
| 1557 | + pytest.param( |
| 1558 | + [("default-gateway", True)], |
| 1559 | + False, |
| 1560 | + "/proxy/services/test-project/test-service/", |
| 1561 | + "/proxy/models/test-project/", |
| 1562 | + id="submits-in-server-when-specified", |
| 1563 | + ), |
| 1564 | + ], |
| 1565 | + ) |
| 1566 | + async def test_submit_to_correct_proxy( |
| 1567 | + self, |
| 1568 | + test_db, |
| 1569 | + session: AsyncSession, |
| 1570 | + client: AsyncClient, |
| 1571 | + existing_gateways: List[Tuple[str, bool]], |
| 1572 | + specified_gateway_in_run_conf: str, |
| 1573 | + expected_service_url: str, |
| 1574 | + expected_model_url: str, |
| 1575 | + ) -> None: |
| 1576 | + user = await create_user(session=session, global_role=GlobalRole.USER) |
| 1577 | + project = await create_project(session=session, owner=user, name="test-project") |
| 1578 | + await add_project_member( |
| 1579 | + session=session, project=project, user=user, project_role=ProjectRole.USER |
| 1580 | + ) |
| 1581 | + repo = await create_repo(session=session, project_id=project.id) |
| 1582 | + backend = await create_backend(session=session, project_id=project.id) |
| 1583 | + for gateway_name, is_default in existing_gateways: |
| 1584 | + gateway_compute = await create_gateway_compute( |
| 1585 | + session=session, |
| 1586 | + backend_id=backend.id, |
| 1587 | + ) |
| 1588 | + gateway = await create_gateway( |
| 1589 | + session=session, |
| 1590 | + project_id=project.id, |
| 1591 | + backend_id=backend.id, |
| 1592 | + gateway_compute_id=gateway_compute.id, |
| 1593 | + status=GatewayStatus.RUNNING, |
| 1594 | + name=gateway_name, |
| 1595 | + wildcard_domain=f"{gateway_name}.example", |
| 1596 | + ) |
| 1597 | + if is_default: |
| 1598 | + project.default_gateway_id = gateway.id |
| 1599 | + await session.commit() |
| 1600 | + run_spec = get_service_run_spec( |
| 1601 | + repo_id=repo.name, |
| 1602 | + run_name="test-service", |
| 1603 | + gateway=specified_gateway_in_run_conf, |
| 1604 | + ) |
| 1605 | + response = await client.post( |
| 1606 | + f"/api/project/{project.name}/runs/submit", |
| 1607 | + headers=get_auth_headers(user.token), |
| 1608 | + json={"run_spec": run_spec}, |
| 1609 | + ) |
| 1610 | + assert response.status_code == 200 |
| 1611 | + assert response.json()["service"]["url"] == expected_service_url |
| 1612 | + assert response.json()["service"]["model"]["base_url"] == expected_model_url |
| 1613 | + |
| 1614 | + @pytest.mark.asyncio |
| 1615 | + async def test_return_error_if_specified_gateway_not_exists( |
| 1616 | + self, test_db, session: AsyncSession, client: AsyncClient |
| 1617 | + ) -> None: |
| 1618 | + user = await create_user(session=session, global_role=GlobalRole.USER) |
| 1619 | + project = await create_project(session=session, owner=user) |
| 1620 | + await add_project_member( |
| 1621 | + session=session, project=project, user=user, project_role=ProjectRole.USER |
| 1622 | + ) |
| 1623 | + repo = await create_repo(session=session, project_id=project.id) |
| 1624 | + run_spec = get_service_run_spec(repo_id=repo.name, gateway="nonexistent") |
| 1625 | + response = await client.post( |
| 1626 | + f"/api/project/{project.name}/runs/submit", |
| 1627 | + headers=get_auth_headers(user.token), |
| 1628 | + json={"run_spec": run_spec}, |
| 1629 | + ) |
| 1630 | + assert response.status_code == 400 |
| 1631 | + assert response.json() == { |
| 1632 | + "detail": [ |
| 1633 | + {"msg": "Gateway nonexistent does not exist", "code": "resource_not_exists"} |
| 1634 | + ] |
| 1635 | + } |
| 1636 | + |
| 1637 | + @pytest.mark.asyncio |
| 1638 | + async def test_return_error_if_specified_gateway_is_true( |
| 1639 | + self, test_db, session: AsyncSession, client: AsyncClient |
| 1640 | + ) -> None: |
| 1641 | + user = await create_user(session=session, global_role=GlobalRole.USER) |
| 1642 | + project = await create_project(session=session, owner=user) |
| 1643 | + await add_project_member( |
| 1644 | + session=session, project=project, user=user, project_role=ProjectRole.USER |
| 1645 | + ) |
| 1646 | + repo = await create_repo(session=session, project_id=project.id) |
| 1647 | + run_spec = get_service_run_spec(repo_id=repo.name, gateway=True) |
| 1648 | + response = await client.post( |
| 1649 | + f"/api/project/{project.name}/runs/submit", |
| 1650 | + headers=get_auth_headers(user.token), |
| 1651 | + json={"run_spec": run_spec}, |
| 1652 | + ) |
| 1653 | + assert response.status_code == 422 |
| 1654 | + assert "must be a string or boolean `false`, not boolean `true`" in response.text |
0 commit comments