Skip to content

🎨 add num_of_seats to pricing unit (for LICENSE type pricing plan) #7271

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Show file tree
Hide file tree
Changes from 8 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -2,23 +2,24 @@
from decimal import Decimal
from typing import NamedTuple

from pydantic import BaseModel, ConfigDict, PositiveInt
from pydantic import BaseModel, ConfigDict, PositiveInt, model_validator

from ..resource_tracker import (
HardwareInfo,
PricingPlanClassification,
PricingPlanId,
PricingUnitCostId,
PricingUnitId,
UnitExtraInfo,
UnitExtraInfoLicense,
UnitExtraInfoTier,
)
from ..services_types import ServiceKey, ServiceVersion


class PricingUnitGet(BaseModel):
class RutPricingUnitGet(BaseModel):
pricing_unit_id: PricingUnitId
unit_name: str
unit_extra_info: UnitExtraInfo
unit_extra_info: UnitExtraInfoTier | UnitExtraInfoLicense
current_cost_per_unit: Decimal
current_cost_per_unit_id: PricingUnitCostId
default: bool
Expand All @@ -30,30 +31,68 @@ class PricingUnitGet(BaseModel):
{
"pricing_unit_id": 1,
"unit_name": "SMALL",
"unit_extra_info": UnitExtraInfo.model_config["json_schema_extra"]["examples"][0], # type: ignore [index]
"unit_extra_info": UnitExtraInfoTier.model_config["json_schema_extra"]["examples"][0], # type: ignore [index]
"current_cost_per_unit": 5.7,
"current_cost_per_unit_id": 1,
"default": True,
"specific_info": hw_config_example,
}
for hw_config_example in HardwareInfo.model_config["json_schema_extra"][
for hw_config_example in HardwareInfo.model_config["json_schema_extra"][ # type: ignore[index,union-attr]
"examples"
] # type: ignore[index,union-attr]
]
]
+ [
{
"pricing_unit_id": 2,
"unit_name": "5 seats",
"unit_extra_info": UnitExtraInfoLicense.model_config["json_schema_extra"]["examples"][0], # type: ignore [index]
"current_cost_per_unit": 10.5,
"current_cost_per_unit_id": 2,
"default": False,
"specific_info": HardwareInfo.model_config["json_schema_extra"][ # type: ignore[index,union-attr]
"examples"
][
1
],
}
]
}
)


class PricingPlanGet(BaseModel):
class RutPricingPlanGet(BaseModel):
pricing_plan_id: PricingPlanId
display_name: str
description: str
classification: PricingPlanClassification
created_at: datetime
pricing_plan_key: str
pricing_units: list[PricingUnitGet] | None
pricing_units: list[RutPricingUnitGet] | None
is_active: bool

@model_validator(mode="after")
def ensure_classification_matches_extra_info(self):
"""Enforce that all PricingUnitGet.unit_extra_info match the plan's classification."""
if not self.pricing_units:
return self # No units to check

for unit in self.pricing_units:
if (
self.classification == PricingPlanClassification.TIER
and not isinstance(unit.unit_extra_info, UnitExtraInfoTier)
):
error_message = (
"For TIER classification, unit_extra_info must be UnitExtraInfoTier"
)
raise ValueError(error_message)
if (
self.classification == PricingPlanClassification.LICENSE
and not isinstance(unit.unit_extra_info, UnitExtraInfoLicense)
):
error_message = "For LICENSE classification, unit_extra_info must be UnitExtraInfoLicense"
raise ValueError(error_message)
return self

model_config = ConfigDict(
json_schema_extra={
"examples": [
Expand All @@ -64,21 +103,34 @@ class PricingPlanGet(BaseModel):
"classification": "TIER",
"created_at": "2023-01-11 13:11:47.293595",
"pricing_plan_key": "pricing-plan-sleeper",
"pricing_units": [pricing_unit_get_example],
"pricing_units": [
RutPricingUnitGet.model_config["json_schema_extra"]["examples"][
0
]
],
"is_active": True,
}
for pricing_unit_get_example in PricingUnitGet.model_config[
"json_schema_extra"
][
"examples"
] # type: ignore[index,union-attr]
},
{
"pricing_plan_id": 2,
"display_name": "VIP model A",
"description": "Special Pricing Plan for VIP",
"classification": "LICENSE",
"created_at": "2023-01-11 13:11:47.293595",
"pricing_plan_key": "vip-model-a",
"pricing_units": [
RutPricingUnitGet.model_config["json_schema_extra"]["examples"][
2
]
],
"is_active": True,
},
]
}
)


class PricingPlanPage(NamedTuple):
items: list[PricingPlanGet]
class RutPricingPlanPage(NamedTuple):
items: list[RutPricingPlanGet]
total: PositiveInt


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,8 @@
PricingUnitId,
ServiceRunStatus,
SpecificInfo,
UnitExtraInfo,
UnitExtraInfoLicense,
UnitExtraInfoTier,
)
from ..services import ServiceKey, ServiceVersion
from ..services_types import ServiceRunID
Expand Down Expand Up @@ -48,7 +49,7 @@ class ServiceRunGet(
class PricingUnitGet(OutputSchema):
pricing_unit_id: PricingUnitId
unit_name: str
unit_extra_info: UnitExtraInfo
unit_extra_info: UnitExtraInfoTier | UnitExtraInfoLicense
current_cost_per_unit: Decimal
default: bool

Expand Down Expand Up @@ -114,7 +115,7 @@ class UpdatePricingPlanBodyParams(InputSchema):

class CreatePricingUnitBodyParams(InputSchema):
unit_name: str
unit_extra_info: UnitExtraInfo
unit_extra_info: UnitExtraInfoTier | UnitExtraInfoLicense
default: bool
specific_info: SpecificInfo
cost_per_unit: Decimal
Expand All @@ -128,7 +129,7 @@ class CreatePricingUnitBodyParams(InputSchema):

class UpdatePricingUnitBodyParams(InputSchema):
unit_name: str
unit_extra_info: UnitExtraInfo
unit_extra_info: UnitExtraInfoTier | UnitExtraInfoLicense
default: bool
specific_info: SpecificInfo
pricing_unit_cost_update: PricingUnitCostUpdate | None = Field(default=None)
Expand Down
31 changes: 25 additions & 6 deletions packages/models-library/src/models_library/resource_tracker.py
Original file line number Diff line number Diff line change
Expand Up @@ -232,7 +232,7 @@ class SpecificInfo(HardwareInfo):
to store aws ec2 instance type."""


class UnitExtraInfo(BaseModel):
class UnitExtraInfoTier(BaseModel):
"""Custom information that is propagated to the frontend. Defined fields are mandatory."""

CPU: NonNegativeInt
Expand All @@ -256,10 +256,29 @@ class UnitExtraInfo(BaseModel):
)


class UnitExtraInfoLicense(BaseModel):
"""Custom information that is propagated to the frontend. Defined fields are mandatory."""

num_of_seats: NonNegativeInt

model_config = ConfigDict(
populate_by_name=True,
extra="allow",
json_schema_extra={
"examples": [
{
"num_of_seats": 5,
"custom key": "custom value",
}
]
},
)


class PricingUnitWithCostCreate(BaseModel):
pricing_plan_id: PricingPlanId
unit_name: str
unit_extra_info: UnitExtraInfo
unit_extra_info: UnitExtraInfoTier | UnitExtraInfoLicense
default: bool
specific_info: SpecificInfo
cost_per_unit: Decimal
Expand All @@ -271,7 +290,7 @@ class PricingUnitWithCostCreate(BaseModel):
{
"pricing_plan_id": 1,
"unit_name": "My pricing plan",
"unit_extra_info": UnitExtraInfo.model_config["json_schema_extra"]["examples"][0], # type: ignore [index]
"unit_extra_info": UnitExtraInfoTier.model_config["json_schema_extra"]["examples"][0], # type: ignore [index]
"default": True,
"specific_info": {"aws_ec2_instances": ["t3.medium"]},
"cost_per_unit": 10,
Expand All @@ -291,7 +310,7 @@ class PricingUnitWithCostUpdate(BaseModel):
pricing_plan_id: PricingPlanId
pricing_unit_id: PricingUnitId
unit_name: str
unit_extra_info: UnitExtraInfo
unit_extra_info: UnitExtraInfoTier | UnitExtraInfoLicense
default: bool
specific_info: SpecificInfo
pricing_unit_cost_update: PricingUnitCostUpdate | None
Expand All @@ -303,7 +322,7 @@ class PricingUnitWithCostUpdate(BaseModel):
"pricing_plan_id": 1,
"pricing_unit_id": 1,
"unit_name": "My pricing plan",
"unit_extra_info": UnitExtraInfo.model_config["json_schema_extra"]["examples"][0], # type: ignore [index]
"unit_extra_info": UnitExtraInfoTier.model_config["json_schema_extra"]["examples"][0], # type: ignore [index]
"default": True,
"specific_info": {"aws_ec2_instances": ["t3.medium"]},
"pricing_unit_cost_update": {
Expand All @@ -315,7 +334,7 @@ class PricingUnitWithCostUpdate(BaseModel):
"pricing_plan_id": 1,
"pricing_unit_id": 1,
"unit_name": "My pricing plan",
"unit_extra_info": UnitExtraInfo.model_config["json_schema_extra"]["examples"][0], # type: ignore [index]
"unit_extra_info": UnitExtraInfoTier.model_config["json_schema_extra"]["examples"][0], # type: ignore [index]
"default": True,
"specific_info": {"aws_ec2_instances": ["t3.medium"]},
"pricing_unit_cost_update": None,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,9 @@
RESOURCE_USAGE_TRACKER_RPC_NAMESPACE,
)
from models_library.api_schemas_resource_usage_tracker.pricing_plans import (
PricingPlanGet,
PricingPlanPage,
PricingPlanToServiceGet,
RutPricingPlanGet,
RutPricingPlanPage,
)
from models_library.products import ProductName
from models_library.rabbitmq_basic_types import RPCMethodName
Expand Down Expand Up @@ -36,38 +36,40 @@ async def get_pricing_plan(
*,
product_name: ProductName,
pricing_plan_id: PricingPlanId,
) -> PricingPlanGet:
result: PricingPlanGet = await rabbitmq_rpc_client.request(
) -> RutPricingPlanGet:
result: RutPricingPlanGet = await rabbitmq_rpc_client.request(
RESOURCE_USAGE_TRACKER_RPC_NAMESPACE,
_RPC_METHOD_NAME_ADAPTER.validate_python("get_pricing_plan"),
product_name=product_name,
pricing_plan_id=pricing_plan_id,
timeout_s=_DEFAULT_TIMEOUT_S,
)
assert isinstance(result, PricingPlanGet) # nosec
assert isinstance(result, RutPricingPlanGet) # nosec
return result


@log_decorator(_logger, level=logging.DEBUG)
async def list_pricing_plans(
async def list_pricing_plans_without_pricing_units(
rabbitmq_rpc_client: RabbitMQRPCClient,
*,
product_name: ProductName,
exclude_inactive: bool = True,
# pagination
offset: int = 0,
limit: int = 20,
) -> PricingPlanPage:
) -> RutPricingPlanPage:
result = await rabbitmq_rpc_client.request(
RESOURCE_USAGE_TRACKER_RPC_NAMESPACE,
_RPC_METHOD_NAME_ADAPTER.validate_python("list_pricing_plans"),
_RPC_METHOD_NAME_ADAPTER.validate_python(
"list_pricing_plans_without_pricing_units"
),
product_name=product_name,
exclude_inactive=exclude_inactive,
offset=offset,
limit=limit,
timeout_s=_DEFAULT_TIMEOUT_S,
)
assert isinstance(result, PricingPlanPage) # nosec
assert isinstance(result, RutPricingPlanPage) # nosec
return result


Expand All @@ -76,14 +78,14 @@ async def create_pricing_plan(
rabbitmq_rpc_client: RabbitMQRPCClient,
*,
data: PricingPlanCreate,
) -> PricingPlanGet:
result: PricingPlanGet = await rabbitmq_rpc_client.request(
) -> RutPricingPlanGet:
result: RutPricingPlanGet = await rabbitmq_rpc_client.request(
RESOURCE_USAGE_TRACKER_RPC_NAMESPACE,
_RPC_METHOD_NAME_ADAPTER.validate_python("create_pricing_plan"),
data=data,
timeout_s=_DEFAULT_TIMEOUT_S,
)
assert isinstance(result, PricingPlanGet) # nosec
assert isinstance(result, RutPricingPlanGet) # nosec
return result


Expand All @@ -93,15 +95,15 @@ async def update_pricing_plan(
*,
product_name: ProductName,
data: PricingPlanUpdate,
) -> PricingPlanGet:
result: PricingPlanGet = await rabbitmq_rpc_client.request(
) -> RutPricingPlanGet:
result: RutPricingPlanGet = await rabbitmq_rpc_client.request(
RESOURCE_USAGE_TRACKER_RPC_NAMESPACE,
_RPC_METHOD_NAME_ADAPTER.validate_python("update_pricing_plan"),
product_name=product_name,
data=data,
timeout_s=_DEFAULT_TIMEOUT_S,
)
assert isinstance(result, PricingPlanGet) # nosec
assert isinstance(result, RutPricingPlanGet) # nosec
return result


Expand All @@ -112,7 +114,7 @@ async def list_connected_services_to_pricing_plan_by_pricing_plan(
product_name: ProductName,
pricing_plan_id: PricingPlanId,
) -> list[PricingPlanToServiceGet]:
result: PricingPlanGet = await rabbitmq_rpc_client.request(
result: RutPricingPlanGet = await rabbitmq_rpc_client.request(
RESOURCE_USAGE_TRACKER_RPC_NAMESPACE,
_RPC_METHOD_NAME_ADAPTER.validate_python(
"list_connected_services_to_pricing_plan_by_pricing_plan"
Expand All @@ -134,7 +136,7 @@ async def connect_service_to_pricing_plan(
service_key: ServiceKey,
service_version: ServiceVersion,
) -> PricingPlanToServiceGet:
result: PricingPlanGet = await rabbitmq_rpc_client.request(
result: RutPricingPlanGet = await rabbitmq_rpc_client.request(
RESOURCE_USAGE_TRACKER_RPC_NAMESPACE,
_RPC_METHOD_NAME_ADAPTER.validate_python("connect_service_to_pricing_plan"),
product_name=product_name,
Expand Down
Loading
Loading