Skip to content

[chore]: Refactor gateway modules #2226

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Jan 28, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions docs/docs/reference/dstack.yml/service.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ The `service` configuration type allows running [services](../../concepts/servic

=== "OpenAI"

#SCHEMA# dstack._internal.core.models.gateways.OpenAIChatModel
#SCHEMA# dstack.api.OpenAIChatModel
overrides:
show_root_heading: false
type:
Expand All @@ -25,7 +25,7 @@ The `service` configuration type allows running [services](../../concepts/servic
> TGI provides an OpenAI-compatible API starting with version 1.4.0,
so models served by TGI can be defined with `format: openai` too.

#SCHEMA# dstack._internal.core.models.gateways.TGIChatModel
#SCHEMA# dstack.api.TGIChatModel
overrides:
show_root_heading: false
type:
Expand Down
3 changes: 2 additions & 1 deletion src/dstack/_internal/core/models/configurations.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,12 @@
from dstack._internal.core.models.common import CoreModel, Duration, RegistryAuth
from dstack._internal.core.models.envs import Env
from dstack._internal.core.models.fleets import FleetConfiguration
from dstack._internal.core.models.gateways import AnyModel, GatewayConfiguration, OpenAIChatModel
from dstack._internal.core.models.gateways import GatewayConfiguration
from dstack._internal.core.models.profiles import ProfileParams
from dstack._internal.core.models.repos.base import Repo
from dstack._internal.core.models.repos.virtual import VirtualRepo
from dstack._internal.core.models.resources import Range, ResourcesSpec
from dstack._internal.core.models.services import AnyModel, OpenAIChatModel
from dstack._internal.core.models.unix import UnixUser
from dstack._internal.core.models.volumes import MountPoint, VolumeConfiguration, parse_mount_point

Expand Down
68 changes: 0 additions & 68 deletions src/dstack/_internal/core/models/gateways.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,6 @@
from dstack._internal.core.models.backends.base import BackendType
from dstack._internal.core.models.common import CoreModel

# TODO(#1595): refactor into different modules: gateway-specific and proxy-specific


class GatewayStatus(str, Enum):
SUBMITTED = "submitted"
Expand Down Expand Up @@ -110,69 +108,3 @@ class GatewayProvisioningData(CoreModel):
availability_zone: Optional[str] = None
hostname: Optional[str] = None
backend_data: Optional[str] = None # backend-specific data in json


class BaseChatModel(CoreModel):
type: Annotated[Literal["chat"], Field(description="The type of the model")] = "chat"
name: Annotated[str, Field(description="The name of the model")]
format: Annotated[
str, Field(description="The serving format. Supported values include `openai` and `tgi`")
]


class TGIChatModel(BaseChatModel):
"""
Mapping of the model for the OpenAI-compatible endpoint.
Attributes:
type (str): The type of the model, e.g. "chat"
name (str): The name of the model. This name will be used both to load model configuration from the HuggingFace Hub and in the OpenAI-compatible endpoint.
format (str): The format of the model, e.g. "tgi" if the model is served with HuggingFace's Text Generation Inference.
chat_template (Optional[str]): The custom prompt template for the model. If not specified, the default prompt template from the HuggingFace Hub configuration will be used.
eos_token (Optional[str]): The custom end of sentence token. If not specified, the default end of sentence token from the HuggingFace Hub configuration will be used.
"""

format: Annotated[
Literal["tgi"], Field(description="The serving format. Must be set to `tgi`")
]
chat_template: Annotated[
Optional[str],
Field(
description=(
"The custom prompt template for the model."
" If not specified, the default prompt template"
" from the HuggingFace Hub configuration will be used"
)
),
] = None # will be set before registering the service
eos_token: Annotated[
Optional[str],
Field(
description=(
"The custom end of sentence token."
" If not specified, the default end of sentence token"
" from the HuggingFace Hub configuration will be used"
)
),
] = None


class OpenAIChatModel(BaseChatModel):
"""
Mapping of the model for the OpenAI-compatible endpoint.
Attributes:
type (str): The type of the model, e.g. "chat"
name (str): The name of the model. This name will be used both to load model configuration from the HuggingFace Hub and in the OpenAI-compatible endpoint.
format (str): The format of the model, i.e. "openai".
prefix (str): The `base_url` prefix: `http://hostname/{prefix}/chat/completions`. Defaults to `/v1`.
"""

format: Annotated[
Literal["openai"], Field(description="The serving format. Must be set to `openai`")
]
prefix: Annotated[str, Field(description="The `base_url` prefix (after hostname)")] = "/v1"


ChatModel = Annotated[Union[TGIChatModel, OpenAIChatModel], Field(discriminator="format")]
AnyModel = Union[ChatModel] # embeddings and etc.
76 changes: 76 additions & 0 deletions src/dstack/_internal/core/models/services.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
"""
Data structures related to `type: service` runs.
"""

from typing import Optional, Union

from pydantic import Field
from typing_extensions import Annotated, Literal

from dstack._internal.core.models.common import CoreModel


class BaseChatModel(CoreModel):
    """
    Common fields shared by all chat-model mappings exposed through the
    OpenAI-compatible endpoint.

    Subclasses narrow ``format`` to a ``Literal`` value so that ``ChatModel``
    (a union discriminated on ``format``) can dispatch to the right class.
    """

    # Always "chat" for chat models; reserved for other model kinds later.
    type: Annotated[Literal["chat"], Field(description="The type of the model")] = "chat"
    # Model name, e.g. a HuggingFace Hub repo id.
    name: Annotated[str, Field(description="The name of the model")]
    # Serving format; subclasses pin this to a specific literal ("openai" or "tgi").
    format: Annotated[
        str, Field(description="The serving format. Supported values include `openai` and `tgi`")
    ]


class TGIChatModel(BaseChatModel):
    """
    Mapping of a model served with HuggingFace's Text Generation Inference (TGI)
    to the OpenAI-compatible endpoint.

    Attributes:
        type (str): The type of the model, e.g. "chat"
        name (str): The name of the model. This name will be used both to load model
            configuration from the HuggingFace Hub and in the OpenAI-compatible endpoint.
        format (str): The format of the model, i.e. "tgi" for models served with TGI.
        chat_template (Optional[str]): The custom prompt template for the model. If not
            specified, the default prompt template from the HuggingFace Hub configuration
            will be used.
        eos_token (Optional[str]): The custom end of sentence token. If not specified, the
            default end of sentence token from the HuggingFace Hub configuration will be used.
    """

    # Discriminator value for the ChatModel union.
    format: Annotated[
        Literal["tgi"], Field(description="The serving format. Must be set to `tgi`")
    ]
    chat_template: Annotated[
        Optional[str],
        Field(
            description=(
                "The custom prompt template for the model."
                " If not specified, the default prompt template"
                " from the HuggingFace Hub configuration will be used"
            )
        ),
    ] = None  # filled in from the Hub config before the service is registered
    eos_token: Annotated[
        Optional[str],
        Field(
            description=(
                "The custom end of sentence token."
                " If not specified, the default end of sentence token"
                " from the HuggingFace Hub configuration will be used"
            )
        ),
    ] = None


class OpenAIChatModel(BaseChatModel):
    """
    Mapping of a model that already speaks the OpenAI chat API to the
    OpenAI-compatible endpoint.

    Attributes:
        type (str): The type of the model, e.g. "chat"
        name (str): The name of the model. This name will be used both to load model
            configuration from the HuggingFace Hub and in the OpenAI-compatible endpoint.
        format (str): The format of the model, i.e. "openai".
        prefix (str): The `base_url` prefix: `http://hostname/{prefix}/chat/completions`.
            Defaults to `/v1`.
    """

    # Discriminator value for the ChatModel union.
    format: Annotated[
        Literal["openai"], Field(description="The serving format. Must be set to `openai`")
    ]
    # Path prefix inserted between the hostname and `/chat/completions`.
    prefix: Annotated[str, Field(description="The `base_url` prefix (after hostname)")] = "/v1"


# Union of chat-model mappings discriminated on the `format` field: pydantic
# selects TGIChatModel for format="tgi" and OpenAIChatModel for format="openai".
ChatModel = Annotated[Union[TGIChatModel, OpenAIChatModel], Field(discriminator="format")]
# Union of every supported model kind. Currently only chat models; kept as a
# Union so other kinds (e.g. embeddings) can be added without changing callers.
AnyModel = Union[ChatModel]
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import joinedload

import dstack._internal.server.services.gateways as gateways
from dstack._internal.core.consts import DSTACK_RUNNER_HTTP_PORT, DSTACK_SHIM_HTTP_PORT
from dstack._internal.core.errors import GatewayError
from dstack._internal.core.models.backends.base import BackendType
Expand All @@ -32,6 +31,7 @@
)
from dstack._internal.server.schemas.runner import TaskStatus
from dstack._internal.server.services import logs as logs_services
from dstack._internal.server.services import services
from dstack._internal.server.services.jobs import (
find_job,
get_job_runtime_data,
Expand Down Expand Up @@ -313,7 +313,7 @@ async def _process_running_job(session: AsyncSession, job_model: JobModel):
and run.run_spec.configuration.type == "service"
):
try:
await gateways.register_replica(session, run_model.gateway_id, run, job_model)
await services.register_replica(session, run_model.gateway_id, run, job_model)
except GatewayError as e:
logger.warning(
"%s: failed to register service replica: %s, age=%s",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from sqlalchemy.orm import joinedload, selectinload

import dstack._internal.server.services.gateways as gateways
import dstack._internal.server.services.gateways.autoscalers as autoscalers
import dstack._internal.server.services.services.autoscalers as autoscalers
from dstack._internal.core.errors import ServerError
from dstack._internal.core.models.profiles import RetryEvent
from dstack._internal.core.models.runs import (
Expand Down
Loading
Loading