Skip to content

Commit da3041d

Browse files
committed
feat: add custom volumes/volume mounts for ray clusters
1 parent 6b0a3cc commit da3041d

File tree

2 files changed

+82
-52
lines changed

2 files changed

+82
-52
lines changed

Diff for: src/codeflare_sdk/ray/cluster/build_ray_cluster.py

+23-3
Original file line numberDiff line numberDiff line change
@@ -249,7 +249,7 @@ def get_pod_spec(cluster: "codeflare_sdk.ray.cluster.Cluster", containers):
249249
"""
250250
pod_spec = V1PodSpec(
251251
containers=containers,
252-
volumes=VOLUMES,
252+
volumes=generate_custom_storage(cluster.config.volumes, VOLUMES),
253253
)
254254
if cluster.config.image_pull_secrets != []:
255255
pod_spec.image_pull_secrets = generate_image_pull_secrets(cluster)
@@ -295,7 +295,9 @@ def get_head_container_spec(
295295
cluster.config.head_memory_limits,
296296
cluster.config.head_extended_resource_requests,
297297
),
298-
volume_mounts=VOLUME_MOUNTS,
298+
volume_mounts=generate_custom_storage(
299+
cluster.config.volume_mounts, VOLUME_MOUNTS
300+
),
299301
)
300302
if cluster.config.envs != {}:
301303
head_container.env = generate_env_vars(cluster)
@@ -337,7 +339,9 @@ def get_worker_container_spec(
337339
cluster.config.worker_memory_limits,
338340
cluster.config.worker_extended_resource_requests,
339341
),
340-
volume_mounts=VOLUME_MOUNTS,
342+
volume_mounts=generate_custom_storage(
343+
cluster.config.volume_mounts, VOLUME_MOUNTS
344+
),
341345
)
342346

343347
if cluster.config.envs != {}:
@@ -521,6 +525,22 @@ def wrap_cluster(
521525

522526

523527
# Etc.
528+
def generate_custom_storage(provided_storage: list, default_storage: list) -> list:
    """
    The generate_custom_storage function combines the user-provided volumes/volume mounts
    with the default volumes/volume mounts.

    Args:
        provided_storage:
            The volumes/volume mounts supplied in the cluster configuration.
            May be empty (or None), in which case only the defaults are used.
        default_storage:
            The default volumes/volume mounts; they are placed after the provided entries.

    Returns:
        A new list containing the provided entries followed by the defaults.
        Neither input list is mutated or returned directly.
    """
    # Always build a fresh list. Returning default_storage itself for an empty
    # provided_storage would alias the shared defaults, so a caller mutating the
    # result would silently change them for every later cluster.
    return [*(provided_storage or []), *default_storage]
542+
543+
524544
def write_to_file(cluster: "codeflare_sdk.ray.cluster.Cluster", resource: dict):
525545
"""
526546
The write_to_file function writes the built Ray Cluster/AppWrapper dict as a yaml file in the .codeflare folder

Diff for: src/codeflare_sdk/ray/cluster/config.py

+59-49
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
import warnings
2323
from dataclasses import dataclass, field, fields
2424
from typing import Dict, List, Optional, Union, get_args, get_origin
25+
from kubernetes.client import V1Volume, V1VolumeMount
2526

2627
dir = pathlib.Path(__file__).parent.parent.resolve()
2728

@@ -41,56 +42,63 @@
4142
@dataclass
4243
class ClusterConfiguration:
4344
"""
44-
This dataclass is used to specify resource requirements and other details, and
45-
is passed in as an argument when creating a Cluster object.
45+
This dataclass is used to specify resource requirements and other details, and
46+
is passed in as an argument when creating a Cluster object.
4647
47-
Args:
48-
name:
49-
The name of the cluster.
50-
namespace:
51-
The namespace in which the cluster should be created.
52-
head_cpus:
53-
The number of CPUs to allocate to the head node.
54-
head_memory:
55-
The amount of memory to allocate to the head node.
56-
head_gpus:
57-
The number of GPUs to allocate to the head node. (Deprecated, use head_extended_resource_requests)
58-
head_extended_resource_requests:
59-
A dictionary of extended resource requests for the head node. ex: {"nvidia.com/gpu": 1}
60-
min_cpus:
61-
The minimum number of CPUs to allocate to each worker.
62-
max_cpus:
63-
The maximum number of CPUs to allocate to each worker.
64-
num_workers:
65-
The number of workers to create.
66-
min_memory:
67-
The minimum amount of memory to allocate to each worker.
68-
max_memory:
69-
The maximum amount of memory to allocate to each worker.
70-
num_gpus:
71-
The number of GPUs to allocate to each worker. (Deprecated, use worker_extended_resource_requests)
72-
appwrapper:
73-
A boolean indicating whether to use an AppWrapper.
74-
envs:
75-
A dictionary of environment variables to set for the cluster.
76-
image:
77-
The image to use for the cluster.
78-
image_pull_secrets:
79-
A list of image pull secrets to use for the cluster.
80-
write_to_file:
81-
A boolean indicating whether to write the cluster configuration to a file.
82-
verify_tls:
83-
A boolean indicating whether to verify TLS when connecting to the cluster.
84-
labels:
85-
A dictionary of labels to apply to the cluster.
86-
worker_extended_resource_requests:
87-
A dictionary of extended resource requests for each worker. ex: {"nvidia.com/gpu": 1}
88-
extended_resource_mapping:
89-
A dictionary of custom resource mappings to map extended resource requests to RayCluster resource names
90-
overwrite_default_resource_mapping:
91-
A boolean indicating whether to overwrite the default resource mapping.
92-
annotations:
93-
A dictionary of annotations to apply to the cluster.
48+
Args:
49+
name:
50+
The name of the cluster.
51+
namespace:
52+
The namespace in which the cluster should be created.
53+
head_cpus:
54+
The number of CPUs to allocate to the head node.
55+
head_memory:
56+
The amount of memory to allocate to the head node.
57+
head_gpus:
58+
The number of GPUs to allocate to the head node. (Deprecated, use head_extended_resource_requests)
59+
head_extended_resource_requests:
60+
A dictionary of extended resource requests for the head node. ex: {"nvidia.com/gpu": 1}
61+
min_cpus:
62+
The minimum number of CPUs to allocate to each worker.
63+
max_cpus:
64+
The maximum number of CPUs to allocate to each worker.
65+
num_workers:
66+
The number of workers to create.
67+
min_memory:
68+
The minimum amount of memory to allocate to each worker.
69+
max_memory:
70+
The maximum amount of memory to allocate to each worker.
71+
num_gpus:
72+
The number of GPUs to allocate to each worker. (Deprecated, use worker_extended_resource_requests)
73+
appwrapper:
74+
A boolean indicating whether to use an AppWrapper.
75+
envs:
76+
A dictionary of environment variables to set for the cluster.
77+
image:
78+
The image to use for the cluster.
79+
image_pull_secrets:
80+
A list of image pull secrets to use for the cluster.
81+
write_to_file:
82+
A boolean indicating whether to write the cluster configuration to a file.
83+
verify_tls:
84+
A boolean indicating whether to verify TLS when connecting to the cluster.
85+
labels:
86+
A dictionary of labels to apply to the cluster.
87+
worker_extended_resource_requests:
88+
A dictionary of extended resource requests for each worker. ex: {"nvidia.com/gpu": 1}
89+
extended_resource_mapping:
90+
A dictionary of custom resource mappings to map extended resource requests to RayCluster resource names
91+
overwrite_default_resource_mapping:
92+
A boolean indicating whether to overwrite the default resource mapping.
94+
annotations:
95+
A dictionary of annotations to apply to the cluster.
97+
volumes:
98+
A list of V1Volume objects to add to the Cluster
99+
volume_mounts:
100+
A list of V1VolumeMount objects to add to the Cluster
94102
"""
95103

96104
name: str
@@ -129,6 +137,8 @@ class ClusterConfiguration:
129137
overwrite_default_resource_mapping: bool = False
130138
local_queue: Optional[str] = None
131139
annotations: Dict[str, str] = field(default_factory=dict)
140+
volumes: list[V1Volume] = field(default_factory=list)
141+
volume_mounts: list[V1VolumeMount] = field(default_factory=list)
132142

133143
def __post_init__(self):
134144
if not self.verify_tls:

0 commit comments

Comments
 (0)