Skip to content

Commit 3d5ab80

Browse files
authored
Add REST plugin for user-defined policies (#2631)
* Add REST plugin * Change rest-plugin to a builtin plugin * Add plugin server example * Document rest-plugin in guides * Refactor rest-plugin + polish models * Restructure rest_plugin modules * Doc updates * Additional type checks, type hints and field descriptions * Unskip postgres tests
1 parent 4daf142 commit 3d5ab80

File tree

16 files changed

+657
-58
lines changed

16 files changed

+657
-58
lines changed

docs/docs/guides/plugins.md

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -113,4 +113,14 @@ class ExamplePolicy(ApplyPolicy):
113113

114114
</div>
115115

116-
For more information on the plugin development, see the [plugin example](https://github.com/dstackai/dstack/tree/master/examples/plugins/example_plugin).
116+
## Built-in Plugins
117+
118+
### REST Plugin
119+
`rest_plugin` is a builtin `dstack` plugin that allows writing your custom plugins as API servers, so you don't need to install plugins as Python packages.
120+
121+
Plugins implemented as API servers have advantages over plugins implemented as Python packages in some cases:
122+
* No dependency conflicts with `dstack`.
123+
* You can use any programming language.
124+
* If you run the `dstack` server via Docker, you don't need to extend the `dstack` server image with plugins or map them via volumes.
125+
126+
To get started, check out the [plugin server example](/examples/plugins/example_plugin_server/README.md).
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
3.11
Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
## Overview
2+
3+
If you wish to hook up your own plugin server through `dstack` builtin `rest_plugin`, here's a basic example on how to do so.
4+
5+
## Steps
6+
7+
8+
1. Install required dependencies for the plugin server:
9+
10+
```bash
11+
uv sync
12+
```
13+
14+
1. Start the plugin server locally:
15+
16+
```bash
17+
fastapi dev app/main.py
18+
```
19+
20+
1. Enable `rest_plugin` in `server/config.yaml`:
21+
22+
```yaml
23+
plugins:
24+
- rest_plugin
25+
```
26+
27+
1. Point the `dstack` server to your plugin server:
28+
```bash
29+
export DSTACK_PLUGIN_SERVICE_URI=http://127.0.0.1:8000
30+
```

examples/plugins/example_plugin_server/app/__init__.py

Whitespace-only changes.
Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
import logging
2+
3+
from fastapi import FastAPI
4+
5+
from app.utils import configure_logging
6+
from dstack.plugins.builtin.rest_plugin import (
7+
FleetSpecRequest,
8+
FleetSpecResponse,
9+
GatewaySpecRequest,
10+
GatewaySpecResponse,
11+
RunSpecRequest,
12+
RunSpecResponse,
13+
VolumeSpecRequest,
14+
VolumeSpecResponse,
15+
)
16+
17+
configure_logging()
18+
logger = logging.getLogger(__name__)
19+
20+
app = FastAPI()
21+
22+
23+
@app.post("/apply_policies/on_run_apply")
24+
async def on_run_apply(request: RunSpecRequest) -> RunSpecResponse:
25+
logger.info(
26+
f"Received run spec request from user {request.user} and project {request.project}"
27+
)
28+
response = RunSpecResponse(spec=request.spec, error=None)
29+
return response
30+
31+
32+
@app.post("/apply_policies/on_fleet_apply")
33+
async def on_fleet_apply(request: FleetSpecRequest) -> FleetSpecResponse:
34+
logger.info(
35+
f"Received fleet spec request from user {request.user} and project {request.project}"
36+
)
37+
response = FleetSpecResponse(request.spec, error=None)
38+
return response
39+
40+
41+
@app.post("/apply_policies/on_volume_apply")
42+
async def on_volume_apply(request: VolumeSpecRequest) -> VolumeSpecResponse:
43+
logger.info(
44+
f"Received volume spec request from user {request.user} and project {request.project}"
45+
)
46+
response = VolumeSpecResponse(request.spec, error=None)
47+
return response
48+
49+
50+
@app.post("/apply_policies/on_gateway_apply")
51+
async def on_gateway_apply(request: GatewaySpecRequest) -> GatewaySpecResponse:
52+
logger.info(
53+
f"Received gateway spec request from user {request.user} and project {request.project}"
54+
)
55+
response = GatewaySpecResponse(request.spec, error=None)
56+
return response
Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
import logging
2+
import os
3+
4+
5+
def configure_logging():
6+
log_level = os.getenv("LOG_LEVEL", "INFO").upper()
7+
logging.basicConfig(level=log_level)
Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
[project]
2+
name = "dstack-plugin-server"
3+
version = "0.1.0"
4+
description = "Example plugin server"
5+
readme = "README.md"
6+
requires-python = ">=3.11"
7+
dependencies = [
8+
"fastapi[standard]>=0.115.12",
9+
"dstack>=0.19.8"
10+
]

src/dstack/_internal/server/services/plugins.py

Lines changed: 61 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import itertools
22
from importlib import import_module
3+
from typing import Dict
34

45
from backports.entry_points_selectable import entry_points # backport for Python 3.9
56

@@ -12,50 +13,80 @@
1213

1314
_PLUGINS: list[Plugin] = []
1415

16+
_BUILTIN_PLUGINS: Dict[str, str] = {"rest_plugin": "dstack.plugins.builtin.rest_plugin:RESTPlugin"}
1517

16-
def load_plugins(enabled_plugins: list[str]):
17-
_PLUGINS.clear()
18-
plugins_entrypoints = entry_points(group="dstack.plugins")
19-
plugins_to_load = enabled_plugins.copy()
20-
for entrypoint in plugins_entrypoints:
21-
if entrypoint.name not in enabled_plugins:
22-
logger.info(
23-
("Found not enabled plugin %s. Plugin will not be loaded."),
24-
entrypoint.name,
25-
)
26-
continue
18+
19+
class PluginEntrypoint:
20+
def __init__(self, name: str, import_path: str, is_builtin: bool = False):
21+
self.name = name
22+
self.import_path = import_path
23+
self.is_builtin = is_builtin
24+
25+
def load(self):
26+
module_path, _, class_name = self.import_path.partition(":")
2727
try:
28-
module_path, _, class_name = entrypoint.value.partition(":")
2928
module = import_module(module_path)
29+
plugin_class = getattr(module, class_name, None)
30+
if plugin_class is None:
31+
logger.warning(
32+
("Failed to load plugin %s: plugin class %s not found in module %s."),
33+
self.name,
34+
class_name,
35+
module_path,
36+
)
37+
return None
38+
if not issubclass(plugin_class, Plugin):
39+
logger.warning(
40+
("Failed to load plugin %s: plugin class %s is not a subclass of Plugin."),
41+
self.name,
42+
class_name,
43+
)
44+
return None
45+
return plugin_class()
3046
except ImportError:
3147
logger.warning(
3248
(
3349
"Failed to load plugin %s when importing %s."
3450
" Ensure the module is on the import path."
3551
),
36-
entrypoint.name,
37-
entrypoint.value,
52+
self.name,
53+
self.import_path,
3854
)
39-
continue
40-
plugin_class = getattr(module, class_name, None)
41-
if plugin_class is None:
42-
logger.warning(
43-
("Failed to load plugin %s: plugin class %s not found in module %s."),
55+
return None
56+
57+
58+
def load_plugins(enabled_plugins: list[str]):
59+
_PLUGINS.clear()
60+
entrypoints: dict[str, PluginEntrypoint] = {}
61+
plugins_to_load = enabled_plugins.copy()
62+
for entrypoint in entry_points(group="dstack.plugins"):
63+
if entrypoint.name not in enabled_plugins:
64+
logger.info(
65+
("Found not enabled plugin %s. Plugin will not be loaded."),
4466
entrypoint.name,
45-
class_name,
46-
module_path,
4767
)
4868
continue
49-
if not issubclass(plugin_class, Plugin):
50-
logger.warning(
51-
("Failed to load plugin %s: plugin class %s is not a subclass of Plugin."),
52-
entrypoint.name,
53-
class_name,
69+
else:
70+
entrypoints[entrypoint.name] = PluginEntrypoint(
71+
entrypoint.name, entrypoint.value, is_builtin=False
5472
)
55-
continue
56-
plugins_to_load.remove(entrypoint.name)
57-
_PLUGINS.append(plugin_class())
58-
logger.info("Loaded plugin %s", entrypoint.name)
73+
74+
for name, import_path in _BUILTIN_PLUGINS.items():
75+
if name not in enabled_plugins:
76+
logger.info(
77+
("Found not enabled builtin plugin %s. Plugin will not be loaded."),
78+
name,
79+
)
80+
else:
81+
entrypoints[name] = PluginEntrypoint(name, import_path, is_builtin=True)
82+
83+
for plugin_name, plugin_entrypoint in entrypoints.items():
84+
plugin_instance = plugin_entrypoint.load()
85+
if plugin_instance is not None:
86+
_PLUGINS.append(plugin_instance)
87+
plugins_to_load.remove(plugin_name)
88+
logger.info("Loaded plugin %s", plugin_name)
89+
5990
if plugins_to_load:
6091
logger.warning("Enabled plugins not found: %s", plugins_to_load)
6192

src/dstack/plugins/builtin/__init__.py

Whitespace-only changes.
Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
# ruff: noqa: F401
2+
from dstack.plugins.builtin.rest_plugin._models import (
3+
FleetSpecRequest,
4+
FleetSpecResponse,
5+
GatewaySpecRequest,
6+
GatewaySpecResponse,
7+
RunSpecRequest,
8+
RunSpecResponse,
9+
SpecApplyRequest,
10+
SpecApplyResponse,
11+
VolumeSpecRequest,
12+
VolumeSpecResponse,
13+
)
14+
from dstack.plugins.builtin.rest_plugin._plugin import (
15+
PLUGIN_SERVICE_URI_ENV_VAR_NAME,
16+
CustomApplyPolicy,
17+
RESTPlugin,
18+
)
Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
from typing import Generic, Optional, TypeVar
2+
3+
from pydantic import BaseModel, Field
4+
from typing_extensions import Annotated
5+
6+
from dstack._internal.core.models.fleets import FleetSpec
7+
from dstack._internal.core.models.gateways import GatewaySpec
8+
from dstack._internal.core.models.runs import RunSpec
9+
from dstack._internal.core.models.volumes import VolumeSpec
10+
11+
SpecType = TypeVar("SpecType", RunSpec, FleetSpec, VolumeSpec, GatewaySpec)
12+
13+
14+
class SpecApplyRequest(BaseModel, Generic[SpecType]):
15+
user: Annotated[str, Field(description="The name of the user making the apply request")]
16+
project: Annotated[str, Field(description="The name of the project the request is for")]
17+
spec: Annotated[SpecType, Field(description="The spec to be applied")]
18+
19+
# Override dict() to remove __orig_class__ attribute and avoid "TypeError: Object of type _GenericAlias is not JSON serializable"
20+
# error. This issue doesn't happen though when running the code in pytest, only when running the server.
21+
def dict(self, *args, **kwargs):
22+
d = super().dict(*args, **kwargs)
23+
d.pop("__orig_class__", None)
24+
return d
25+
26+
27+
RunSpecRequest = SpecApplyRequest[RunSpec]
28+
FleetSpecRequest = SpecApplyRequest[FleetSpec]
29+
VolumeSpecRequest = SpecApplyRequest[VolumeSpec]
30+
GatewaySpecRequest = SpecApplyRequest[GatewaySpec]
31+
32+
33+
class SpecApplyResponse(BaseModel, Generic[SpecType]):
34+
spec: Annotated[
35+
SpecType,
36+
Field(
37+
description="The spec to apply, original spec if error otherwise original or mutated by plugin service if approved"
38+
),
39+
]
40+
error: Annotated[
41+
Optional[str], Field(description="Error message if request is rejected", min_length=1)
42+
] = None
43+
44+
45+
RunSpecResponse = SpecApplyResponse[RunSpec]
46+
FleetSpecResponse = SpecApplyResponse[FleetSpec]
47+
VolumeSpecResponse = SpecApplyResponse[VolumeSpec]
48+
GatewaySpecResponse = SpecApplyResponse[GatewaySpec]

0 commit comments

Comments
 (0)