Skip to content

Commit 4ebb3dd

Browse files
committed
feat: enable Ray cluster head pod persistency
Signed-off-by: kramaranya <[email protected]>
1 parent c311665 commit 4ebb3dd

File tree

2 files changed: +57 -4 lines changed

src/codeflare_sdk/ray/cluster/build_ray_cluster.py

+25
Original file line number | Diff line number | Diff line change
@@ -170,6 +170,31 @@ def build_ray_cluster(cluster: "codeflare_sdk.ray.cluster.Cluster"):
170170
},
171171
}
172172

173+
if cluster.config.enable_gcs_ft:
174+
if not cluster.config.redis_address:
175+
raise ValueError(
176+
"redis_address must be provided when enable_gcs_ft is True"
177+
)
178+
179+
gcs_ft_options = {"redisAddress": cluster.config.redis_address}
180+
181+
if cluster.config.external_storage_namespace:
182+
gcs_ft_options[
183+
"externalStorageNamespace"
184+
] = cluster.config.external_storage_namespace
185+
186+
if cluster.config.redis_password_secret:
187+
gcs_ft_options["redisPassword"] = {
188+
"valueFrom": {
189+
"secretKeyRef": {
190+
"name": cluster.config.redis_password_secret["name"],
191+
"key": cluster.config.redis_password_secret["key"],
192+
}
193+
}
194+
}
195+
196+
resource["spec"]["gcsFaultToleranceOptions"] = gcs_ft_options
197+
173198
config_check()
174199
k8s_client = get_api_client() or client.ApiClient()
175200

src/codeflare_sdk/ray/cluster/config.py

+32 -4
Original file line number | Diff line number | Diff line change
@@ -142,13 +142,38 @@ class ClusterConfiguration:
142142
annotations: Dict[str, str] = field(default_factory=dict)
143143
volumes: list[V1Volume] = field(default_factory=list)
144144
volume_mounts: list[V1VolumeMount] = field(default_factory=list)
145+
enable_gcs_ft: bool = False
146+
redis_address: Optional[str] = None
147+
redis_password_secret: Optional[Dict[str, str]] = None
148+
external_storage_namespace: Optional[str] = None
145149

146150
def __post_init__(self):
147151
if not self.verify_tls:
148152
print(
149153
"Warning: TLS verification has been disabled - Endpoint checks will be bypassed"
150154
)
151155

156+
if self.enable_gcs_ft:
157+
if not self.redis_address:
158+
raise ValueError(
159+
"redis_address must be provided when enable_gcs_ft is True"
160+
)
161+
162+
if self.redis_password_secret and not isinstance(
163+
self.redis_password_secret, dict
164+
):
165+
raise ValueError(
166+
"redis_password_secret must be a dictionary with 'name' and 'key' fields"
167+
)
168+
169+
if self.redis_password_secret and (
170+
"name" not in self.redis_password_secret
171+
or "key" not in self.redis_password_secret
172+
):
173+
raise ValueError(
174+
"redis_password_secret must contain both 'name' and 'key' fields"
175+
)
176+
152177
self._validate_types()
153178
self._memory_to_resource()
154179
self._memory_to_string()
@@ -283,10 +308,13 @@ def check_type(value, expected_type):
283308
else:
284309
return True
285310
if origin_type is dict:
286-
return all(
287-
check_type(k, args[0]) and check_type(v, args[1])
288-
for k, v in value.items()
289-
)
311+
if value is not None:
312+
return all(
313+
check_type(k, args[0]) and check_type(v, args[1])
314+
for k, v in value.items()
315+
)
316+
else:
317+
return True
290318
if origin_type is tuple:
291319
return all(check_type(elem, etype) for elem, etype in zip(value, args))
292320
if expected_type is int:

0 commit comments

Comments
 (0)