|
18 | 18 | Cluster object.
|
19 | 19 | """
|
20 | 20 |
|
21 |
| -from dataclasses import dataclass, field |
22 | 21 | import pathlib
|
23 |
| -import typing |
24 | 22 | import warnings
|
| 23 | +from dataclasses import dataclass, field, fields |
| 24 | +from typing import Dict, List, Optional, Union, get_args, get_origin |
25 | 25 |
|
26 | 26 | dir = pathlib.Path(__file__).parent.parent.resolve()
|
27 | 27 |
|
@@ -73,43 +73,45 @@ class ClusterConfiguration:
|
73 | 73 | """
|
74 | 74 |
|
75 | 75 | name: str
|
76 |
| - namespace: str = None |
77 |
| - head_info: list = field(default_factory=list) |
78 |
| - head_cpus: typing.Union[int, str] = 2 |
79 |
| - head_memory: typing.Union[int, str] = 8 |
80 |
| - head_gpus: int = None # Deprecating |
81 |
| - head_extended_resource_requests: typing.Dict[str, int] = field(default_factory=dict) |
82 |
| - machine_types: list = field(default_factory=list) # ["m4.xlarge", "g4dn.xlarge"] |
83 |
| - worker_cpu_requests: typing.Union[int, str] = 1 |
84 |
| - worker_cpu_limits: typing.Union[int, str] = 1 |
85 |
| - min_cpus: typing.Union[int, str] = None # Deprecating |
86 |
| - max_cpus: typing.Union[int, str] = None # Deprecating |
| 76 | + namespace: Optional[str] = None |
| 77 | + head_info: List[str] = field(default_factory=list) |
| 78 | + head_cpus: Union[int, str] = 2 |
| 79 | + head_memory: Union[int, str] = 8 |
| 80 | + head_gpus: Optional[int] = None # Deprecating |
| 81 | + head_extended_resource_requests: Dict[str, int] = field(default_factory=dict) |
| 82 | + machine_types: List[str] = field( |
| 83 | + default_factory=list |
| 84 | + ) # ["m4.xlarge", "g4dn.xlarge"] |
| 85 | + worker_cpu_requests: Union[int, str] = 1 |
| 86 | + worker_cpu_limits: Union[int, str] = 1 |
| 87 | + min_cpus: Optional[Union[int, str]] = None # Deprecating |
| 88 | + max_cpus: Optional[Union[int, str]] = None # Deprecating |
87 | 89 | num_workers: int = 1
|
88 |
| - worker_memory_requests: typing.Union[int, str] = 2 |
89 |
| - worker_memory_limits: typing.Union[int, str] = 2 |
90 |
| - min_memory: typing.Union[int, str] = None # Deprecating |
91 |
| - max_memory: typing.Union[int, str] = None # Deprecating |
92 |
| - num_gpus: int = None # Deprecating |
| 90 | + worker_memory_requests: Union[int, str] = 2 |
| 91 | + worker_memory_limits: Union[int, str] = 2 |
| 92 | + min_memory: Optional[Union[int, str]] = None # Deprecating |
| 93 | + max_memory: Optional[Union[int, str]] = None # Deprecating |
| 94 | + num_gpus: Optional[int] = None # Deprecating |
93 | 95 | template: str = f"{dir}/templates/base-template.yaml"
|
94 | 96 | appwrapper: bool = False
|
95 |
| - envs: dict = field(default_factory=dict) |
| 97 | + envs: Dict[str, str] = field(default_factory=dict) |
96 | 98 | image: str = ""
|
97 |
| - image_pull_secrets: list = field(default_factory=list) |
| 99 | + image_pull_secrets: List[str] = field(default_factory=list) |
98 | 100 | write_to_file: bool = False
|
99 | 101 | verify_tls: bool = True
|
100 |
| - labels: dict = field(default_factory=dict) |
101 |
| - worker_extended_resource_requests: typing.Dict[str, int] = field( |
102 |
| - default_factory=dict |
103 |
| - ) |
104 |
| - extended_resource_mapping: typing.Dict[str, str] = field(default_factory=dict) |
| 102 | + labels: Dict[str, str] = field(default_factory=dict) |
| 103 | + worker_extended_resource_requests: Dict[str, int] = field(default_factory=dict) |
| 104 | + extended_resource_mapping: Dict[str, str] = field(default_factory=dict) |
105 | 105 | overwrite_default_resource_mapping: bool = False
|
| 106 | + local_queue: Optional[str] = None |
106 | 107 |
|
107 | 108 | def __post_init__(self):
|
108 | 109 | if not self.verify_tls:
|
109 | 110 | print(
|
110 | 111 | "Warning: TLS verification has been disabled - Endpoint checks will be bypassed"
|
111 | 112 | )
|
112 | 113 |
|
| 114 | + self._validate_types() |
113 | 115 | self._memory_to_string()
|
114 | 116 | self._str_mem_no_unit_add_GB()
|
115 | 117 | self._memory_to_resource()
|
@@ -139,9 +141,7 @@ def _combine_extended_resource_mapping(self):
|
139 | 141 | **self.extended_resource_mapping,
|
140 | 142 | }
|
141 | 143 |
|
142 |
| - def _validate_extended_resource_requests( |
143 |
| - self, extended_resources: typing.Dict[str, int] |
144 |
| - ): |
| 144 | + def _validate_extended_resource_requests(self, extended_resources: Dict[str, int]): |
145 | 145 | for k in extended_resources.keys():
|
146 | 146 | if k not in self.extended_resource_mapping.keys():
|
147 | 147 | raise ValueError(
|
@@ -206,4 +206,34 @@ def _memory_to_resource(self):
|
206 | 206 | warnings.warn("max_memory is being deprecated, use worker_memory_limits")
|
207 | 207 | self.worker_memory_limits = f"{self.max_memory}G"
|
208 | 208 |
|
209 |
| - local_queue: str = None |
| 209 | + def _validate_types(self): |
| 210 | + """Validate the types of all fields in the ClusterConfiguration dataclass.""" |
| 211 | + for field_info in fields(self): |
| 212 | + value = getattr(self, field_info.name) |
| 213 | + expected_type = field_info.type |
| 214 | + if not self._is_type(value, expected_type): |
| 215 | + raise TypeError( |
| 216 | + f"'{field_info.name}' should be of type {expected_type}" |
| 217 | + ) |
| 218 | + |
| 219 | + @staticmethod |
| 220 | + def _is_type(value, expected_type): |
| 221 | + """Check if the value matches the expected type.""" |
| 222 | + |
| 223 | + def check_type(value, expected_type): |
| 224 | + origin_type = get_origin(expected_type) |
| 225 | + args = get_args(expected_type) |
| 226 | + if origin_type is Union: |
| 227 | + return any(check_type(value, union_type) for union_type in args) |
| 228 | + if origin_type is list: |
| 229 | + return all(check_type(elem, args[0]) for elem in value) |
| 230 | + if origin_type is dict: |
| 231 | + return all( |
| 232 | + check_type(k, args[0]) and check_type(v, args[1]) |
| 233 | + for k, v in value.items() |
| 234 | + ) |
| 235 | + if origin_type is tuple: |
| 236 | + return all(check_type(elem, etype) for elem, etype in zip(value, args)) |
| 237 | + return isinstance(value, expected_type) |
| 238 | + |
| 239 | + return check_type(value, expected_type) |
0 commit comments