Skip to content

Commit 4dca262

Browse files
Ygnasopenshift-merge-bot[bot]
authored andcommitted
Add validation for Cluster configuration parameters
1 parent e7a45ba commit 4dca262

File tree

1 file changed

+59
-29
lines changed

1 file changed

+59
-29
lines changed

src/codeflare_sdk/cluster/config.py

+59-29
Original file line numberDiff line numberDiff line change
@@ -18,10 +18,10 @@
1818
Cluster object.
1919
"""
2020

21-
from dataclasses import dataclass, field
2221
import pathlib
23-
import typing
2422
import warnings
23+
from dataclasses import dataclass, field, fields
24+
from typing import Dict, List, Optional, Union, get_args, get_origin
2525

2626
dir = pathlib.Path(__file__).parent.parent.resolve()
2727

@@ -73,43 +73,45 @@ class ClusterConfiguration:
7373
"""
7474

7575
name: str
76-
namespace: str = None
77-
head_info: list = field(default_factory=list)
78-
head_cpus: typing.Union[int, str] = 2
79-
head_memory: typing.Union[int, str] = 8
80-
head_gpus: int = None # Deprecating
81-
head_extended_resource_requests: typing.Dict[str, int] = field(default_factory=dict)
82-
machine_types: list = field(default_factory=list) # ["m4.xlarge", "g4dn.xlarge"]
83-
worker_cpu_requests: typing.Union[int, str] = 1
84-
worker_cpu_limits: typing.Union[int, str] = 1
85-
min_cpus: typing.Union[int, str] = None # Deprecating
86-
max_cpus: typing.Union[int, str] = None # Deprecating
76+
namespace: Optional[str] = None
77+
head_info: List[str] = field(default_factory=list)
78+
head_cpus: Union[int, str] = 2
79+
head_memory: Union[int, str] = 8
80+
head_gpus: Optional[int] = None # Deprecating
81+
head_extended_resource_requests: Dict[str, int] = field(default_factory=dict)
82+
machine_types: List[str] = field(
83+
default_factory=list
84+
) # ["m4.xlarge", "g4dn.xlarge"]
85+
worker_cpu_requests: Union[int, str] = 1
86+
worker_cpu_limits: Union[int, str] = 1
87+
min_cpus: Optional[Union[int, str]] = None # Deprecating
88+
max_cpus: Optional[Union[int, str]] = None # Deprecating
8789
num_workers: int = 1
88-
worker_memory_requests: typing.Union[int, str] = 2
89-
worker_memory_limits: typing.Union[int, str] = 2
90-
min_memory: typing.Union[int, str] = None # Deprecating
91-
max_memory: typing.Union[int, str] = None # Deprecating
92-
num_gpus: int = None # Deprecating
90+
worker_memory_requests: Union[int, str] = 2
91+
worker_memory_limits: Union[int, str] = 2
92+
min_memory: Optional[Union[int, str]] = None # Deprecating
93+
max_memory: Optional[Union[int, str]] = None # Deprecating
94+
num_gpus: Optional[int] = None # Deprecating
9395
template: str = f"{dir}/templates/base-template.yaml"
9496
appwrapper: bool = False
95-
envs: dict = field(default_factory=dict)
97+
envs: Dict[str, str] = field(default_factory=dict)
9698
image: str = ""
97-
image_pull_secrets: list = field(default_factory=list)
99+
image_pull_secrets: List[str] = field(default_factory=list)
98100
write_to_file: bool = False
99101
verify_tls: bool = True
100-
labels: dict = field(default_factory=dict)
101-
worker_extended_resource_requests: typing.Dict[str, int] = field(
102-
default_factory=dict
103-
)
104-
extended_resource_mapping: typing.Dict[str, str] = field(default_factory=dict)
102+
labels: Dict[str, str] = field(default_factory=dict)
103+
worker_extended_resource_requests: Dict[str, int] = field(default_factory=dict)
104+
extended_resource_mapping: Dict[str, str] = field(default_factory=dict)
105105
overwrite_default_resource_mapping: bool = False
106+
local_queue: Optional[str] = None
106107

107108
def __post_init__(self):
108109
if not self.verify_tls:
109110
print(
110111
"Warning: TLS verification has been disabled - Endpoint checks will be bypassed"
111112
)
112113

114+
self._validate_types()
113115
self._memory_to_string()
114116
self._str_mem_no_unit_add_GB()
115117
self._memory_to_resource()
@@ -139,9 +141,7 @@ def _combine_extended_resource_mapping(self):
139141
**self.extended_resource_mapping,
140142
}
141143

142-
def _validate_extended_resource_requests(
143-
self, extended_resources: typing.Dict[str, int]
144-
):
144+
def _validate_extended_resource_requests(self, extended_resources: Dict[str, int]):
145145
for k in extended_resources.keys():
146146
if k not in self.extended_resource_mapping.keys():
147147
raise ValueError(
@@ -206,4 +206,34 @@ def _memory_to_resource(self):
206206
warnings.warn("max_memory is being deprecated, use worker_memory_limits")
207207
self.worker_memory_limits = f"{self.max_memory}G"
208208

209-
local_queue: str = None
209+
def _validate_types(self):
210+
"""Validate the types of all fields in the ClusterConfiguration dataclass."""
211+
for field_info in fields(self):
212+
value = getattr(self, field_info.name)
213+
expected_type = field_info.type
214+
if not self._is_type(value, expected_type):
215+
raise TypeError(
216+
f"'{field_info.name}' should be of type {expected_type}"
217+
)
218+
219+
@staticmethod
220+
def _is_type(value, expected_type):
221+
"""Check if the value matches the expected type."""
222+
223+
def check_type(value, expected_type):
224+
origin_type = get_origin(expected_type)
225+
args = get_args(expected_type)
226+
if origin_type is Union:
227+
return any(check_type(value, union_type) for union_type in args)
228+
if origin_type is list:
229+
return all(check_type(elem, args[0]) for elem in value)
230+
if origin_type is dict:
231+
return all(
232+
check_type(k, args[0]) and check_type(v, args[1])
233+
for k, v in value.items()
234+
)
235+
if origin_type is tuple:
236+
return all(check_type(elem, etype) for elem, etype in zip(value, args))
237+
return isinstance(value, expected_type)
238+
239+
return check_type(value, expected_type)

0 commit comments

Comments
 (0)