Skip to content

Commit 0bc477d

Browse files
committed
Add arguments to pass Ray cluster head and worker templates
1 parent e863e29 commit 0bc477d

File tree

6 files changed

+79
-15
lines changed

6 files changed

+79
-15
lines changed

poetry.lock

+12-1
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

pyproject.toml

+1
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@ cryptography = "40.0.2"
2929
executing = "1.2.0"
3030
pydantic = "< 2"
3131
ipywidgets = "8.1.2"
32+
mergedeep = "1.3.4"
3233

3334
[tool.poetry.group.docs]
3435
optional = true

src/codeflare_sdk/cluster/cluster.py

+5-5
Original file line numberDiff line numberDiff line change
@@ -18,11 +18,9 @@
1818
cluster setup queue, a list of all existing clusters, and the user's working namespace.
1919
"""
2020

21-
import re
2221
from time import sleep
2322
from typing import List, Optional, Tuple, Dict
2423

25-
from kubernetes import config
2624
from ray.job_submission import JobSubmissionClient
2725

2826
from .auth import config_check, api_config_handler
@@ -41,13 +39,11 @@
4139
RayCluster,
4240
RayClusterStatus,
4341
)
44-
from kubernetes import client, config
45-
from kubernetes.utils import parse_quantity
4642
import yaml
4743
import os
4844
import requests
4945

50-
from kubernetes import config
46+
from kubernetes import client, config
5147
from kubernetes.client.rest import ApiException
5248

5349

@@ -145,6 +141,8 @@ def create_app_wrapper(self):
145141
gpu = self.config.num_gpus
146142
workers = self.config.num_workers
147143
template = self.config.template
144+
head_template = self.config.head_template
145+
worker_template = self.config.worker_template
148146
image = self.config.image
149147
appwrapper = self.config.appwrapper
150148
env = self.config.envs
@@ -167,6 +165,8 @@ def create_app_wrapper(self):
167165
gpu=gpu,
168166
workers=workers,
169167
template=template,
168+
head_template=head_template,
169+
worker_template=worker_template,
170170
image=image,
171171
appwrapper=appwrapper,
172172
env=env,

src/codeflare_sdk/cluster/config.py

+4
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,8 @@
2222
import pathlib
2323
import typing
2424

25+
import kubernetes
26+
2527
dir = pathlib.Path(__file__).parent.parent.resolve()
2628

2729

@@ -46,6 +48,8 @@ class ClusterConfiguration:
4648
max_memory: typing.Union[int, str] = 2
4749
num_gpus: int = 0
4850
template: str = f"{dir}/templates/base-template.yaml"
51+
head_template: kubernetes.client.V1PodTemplateSpec = None
52+
worker_template: kubernetes.client.V1PodTemplateSpec = None
4953
appwrapper: bool = False
5054
envs: dict = field(default_factory=dict)
5155
image: str = ""

src/codeflare_sdk/utils/generate_yaml.py

+19-5
Original file line numberDiff line numberDiff line change
@@ -20,16 +20,12 @@
2020
from typing import Optional
2121
import typing
2222
import yaml
23-
import sys
2423
import os
25-
import argparse
2624
import uuid
2725
from kubernetes import client, config
2826
from .kube_api_helpers import _kube_api_error_handling
2927
from ..cluster.auth import api_config_handler, config_check
30-
from os import urandom
31-
from base64 import b64encode
32-
from urllib3.util import parse_url
28+
from mergedeep import merge, Strategy
3329

3430

3531
def read_template(template):
@@ -278,6 +274,16 @@ def write_user_yaml(user_yaml, output_file_name):
278274
print(f"Written to: {output_file_name}")
279275

280276

277+
def apply_head_template(cluster_yaml: dict, head_template: client.V1PodTemplateSpec):
278+
head = cluster_yaml.get("spec").get("headGroupSpec")
279+
merge(head["template"], head_template.to_dict(), strategy=Strategy.ADDITIVE)
280+
281+
282+
def apply_worker_template(cluster_yaml: dict, worker_template: client.V1PodTemplateSpec):
283+
worker = cluster_yaml.get("spec").get("workerGroupSpecs")[0]
284+
merge(worker["template"], worker_template.to_dict(), strategy=Strategy.ADDITIVE)
285+
286+
281287
def generate_appwrapper(
282288
name: str,
283289
namespace: str,
@@ -291,6 +297,8 @@ def generate_appwrapper(
291297
gpu: int,
292298
workers: int,
293299
template: str,
300+
head_template: client.V1PodTemplateSpec,
301+
worker_template: client.V1PodTemplateSpec,
294302
image: str,
295303
appwrapper: bool,
296304
env,
@@ -302,6 +310,12 @@ def generate_appwrapper(
302310
volume_mounts: list[client.V1VolumeMount],
303311
):
304312
cluster_yaml = read_template(template)
313+
314+
if head_template:
315+
apply_head_template(cluster_yaml, head_template)
316+
if worker_template:
317+
apply_worker_template(cluster_yaml, worker_template)
318+
305319
appwrapper_name, cluster_name = gen_names(name)
306320
update_names(cluster_yaml, cluster_name, namespace)
307321
update_nodes(

tests/unit_test.py

+38-4
Original file line numberDiff line numberDiff line change
@@ -20,8 +20,6 @@
2020
import re
2121
import uuid
2222

23-
from codeflare_sdk.cluster import cluster
24-
2523
parent = Path(__file__).resolve().parents[1]
2624
aw_dir = os.path.expanduser("~/.codeflare/resources/")
2725
sys.path.append(str(parent) + "/src")
@@ -69,17 +67,18 @@
6967
createClusterConfig,
7068
)
7169

72-
import codeflare_sdk.utils.kube_api_helpers
7370
from codeflare_sdk.utils.generate_yaml import (
7471
gen_names,
7572
is_openshift_cluster,
7673
)
7774

7875
import openshift
79-
from openshift.selector import Selector
8076
import ray
8177
import pytest
8278
import yaml
79+
80+
from kubernetes.client import V1PodTemplateSpec, V1PodSpec, V1Toleration
81+
8382
from unittest.mock import MagicMock
8483
from pytest_mock import MockerFixture
8584
from ray.job_submission import JobSubmissionClient
@@ -268,6 +267,41 @@ def test_config_creation():
268267
assert config.appwrapper == True
269268

270269

270+
def test_cluster_config_with_worker_template(mocker):
271+
mocker.patch("kubernetes.client.ApisApi.get_api_versions")
272+
mocker.patch(
273+
"kubernetes.client.CustomObjectsApi.list_namespaced_custom_object",
274+
return_value=get_local_queue("kueue.x-k8s.io", "v1beta1", "ns", "localqueues"),
275+
)
276+
277+
cluster = Cluster(ClusterConfiguration(
278+
name="unit-test-cluster",
279+
namespace="ns",
280+
num_workers=2,
281+
min_cpus=3,
282+
max_cpus=4,
283+
min_memory=5,
284+
max_memory=6,
285+
num_gpus=7,
286+
image="test/ray:2.20.0-py39-cu118",
287+
worker_template=V1PodTemplateSpec(
288+
spec=V1PodSpec(
289+
containers=[],
290+
tolerations=[V1Toleration(
291+
key="nvidia.com/gpu",
292+
operator="Exists",
293+
effect="NoSchedule",
294+
)],
295+
node_selector={
296+
"nvidia.com/gpu.present": "true",
297+
},
298+
)
299+
),
300+
))
301+
302+
assert cluster
303+
304+
271305
def test_cluster_creation(mocker):
272306
# Create AppWrapper containing a Ray Cluster with no local queue specified
273307
mocker.patch("kubernetes.client.ApisApi.get_api_versions")

0 commit comments

Comments
 (0)