Skip to content

Commit eb83ae5

Browse files
committed
adding validation for local_queue provided in cluster config
1 parent 179fd75 commit eb83ae5

File tree

2 files changed

+54
-17
lines changed

2 files changed

+54
-17
lines changed

Diff for: src/codeflare_sdk/utils/generate_yaml.py

+46-17
Original file line numberDiff line numberDiff line change
@@ -308,6 +308,26 @@ def get_default_kueue_name(namespace: str):
308308
)
309309

310310

311+
def local_queue_exists(namespace: str, local_queue_name: str):
312+
# get all local queues in the namespace
313+
try:
314+
config_check()
315+
api_instance = client.CustomObjectsApi(api_config_handler())
316+
local_queues = api_instance.list_namespaced_custom_object(
317+
group="kueue.x-k8s.io",
318+
version="v1beta1",
319+
namespace=namespace,
320+
plural="localqueues",
321+
)
322+
except Exception as e: # pragma: no cover
323+
return _kube_api_error_handling(e)
324+
# check if local queue with the name provided in cluster config exists
325+
for lq in local_queues["items"]:
326+
if lq["metadata"]["name"] == local_queue_name:
327+
return True
328+
return False
329+
330+
311331
def write_components(
312332
user_yaml: dict,
313333
output_file_name: str,
@@ -324,24 +344,29 @@ def write_components(
324344
open(output_file_name, "w").close()
325345
lq_name = local_queue or get_default_kueue_name(namespace)
326346
cluster_labels = labels
327-
with open(output_file_name, "a") as outfile:
328-
for component in components:
329-
if "generictemplate" in component:
330-
if (
331-
"workload.codeflare.dev/appwrapper"
332-
in component["generictemplate"]["metadata"]["labels"]
333-
):
334-
del component["generictemplate"]["metadata"]["labels"][
347+
if local_queue_exists(namespace, lq_name):
348+
with open(output_file_name, "a") as outfile:
349+
for component in components:
350+
if "generictemplate" in component:
351+
if (
335352
"workload.codeflare.dev/appwrapper"
336-
]
337-
labels = component["generictemplate"]["metadata"]["labels"]
338-
labels.update({"kueue.x-k8s.io/queue-name": lq_name})
339-
labels.update(cluster_labels)
340-
outfile.write("---\n")
341-
yaml.dump(
342-
component["generictemplate"], outfile, default_flow_style=False
343-
)
344-
print(f"Written to: {output_file_name}")
353+
in component["generictemplate"]["metadata"]["labels"]
354+
):
355+
del component["generictemplate"]["metadata"]["labels"][
356+
"workload.codeflare.dev/appwrapper"
357+
]
358+
labels = component["generictemplate"]["metadata"]["labels"]
359+
labels.update({"kueue.x-k8s.io/queue-name": lq_name})
360+
labels.update(cluster_labels)
361+
outfile.write("---\n")
362+
yaml.dump(
363+
component["generictemplate"], outfile, default_flow_style=False
364+
)
365+
print(f"Written to: {output_file_name}")
366+
else:
367+
raise ValueError(
368+
"local_queue provided does not exist. Please provide the correct local_queue name in Cluster Configuration"
369+
)
345370

346371

347372
def load_components(
@@ -355,6 +380,10 @@ def load_components(
355380
components = user_yaml.get("spec", "resources")["resources"].get("GenericItems")
356381
lq_name = local_queue or get_default_kueue_name(namespace)
357382
cluster_labels = labels
383+
if not local_queue_exists(namespace, lq_name):
384+
raise ValueError(
385+
"local_queue provided does not exist or is not in this namespace. Please provide the correct local_queue name in Cluster Configuration"
386+
)
358387
for component in components:
359388
if "generictemplate" in component:
360389
if (

Diff for: tests/unit_test.py

+8
Original file line numberDiff line numberDiff line change
@@ -344,6 +344,10 @@ def test_cluster_creation_no_mcad_local_queue(mocker):
344344
"kubernetes.client.CustomObjectsApi.get_cluster_custom_object",
345345
return_value={"spec": {"domain": "apps.cluster.awsroute.org"}},
346346
)
347+
mocker.patch(
348+
"kubernetes.client.CustomObjectsApi.list_namespaced_custom_object",
349+
return_value=get_local_queue("kueue.x-k8s.io", "v1beta1", "ns", "localqueues"),
350+
)
347351
config = createClusterConfig()
348352
config.name = "unit-test-cluster-ray"
349353
config.mcad = False
@@ -3015,6 +3019,10 @@ def test_cluster_throw_for_no_raycluster(mocker: MockerFixture):
30153019
"codeflare_sdk.utils.generate_yaml.get_default_kueue_name",
30163020
return_value="default",
30173021
)
3022+
mocker.patch(
3023+
"codeflare_sdk.utils.generate_yaml.local_queue_exists",
3024+
return_value="true",
3025+
)
30183026

30193027
def throw_if_getting_raycluster(group, version, namespace, plural):
30203028
if plural == "rayclusters":

0 commit comments

Comments
 (0)