Skip to content

Commit 07b4c6d

Browse files
committed
Refactor: kueue module
1 parent 993cea1 commit 07b4c6d

File tree

4 files changed

+91
-75
lines changed

4 files changed

+91
-75
lines changed

Diff for: src/codeflare_sdk/common/kueue/__init__.py

+1
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
from .kueue import *

Diff for: src/codeflare_sdk/common/kueue/kueue.py

+78
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,78 @@
1+
# Copyright 2024 IBM, Red Hat
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from typing import Optional
16+
from codeflare_sdk.common import _kube_api_error_handling
17+
from codeflare_sdk.common.kubernetes_cluster.auth import config_check, get_api_client
18+
from kubernetes import client
19+
from kubernetes.client.exceptions import ApiException
20+
21+
22+
def get_default_kueue_name(namespace: str):
23+
# If the local queue is set, use it. Otherwise, try to use the default queue.
24+
try:
25+
config_check()
26+
api_instance = client.CustomObjectsApi(get_api_client())
27+
local_queues = api_instance.list_namespaced_custom_object(
28+
group="kueue.x-k8s.io",
29+
version="v1beta1",
30+
namespace=namespace,
31+
plural="localqueues",
32+
)
33+
except ApiException as e: # pragma: no cover
34+
if e.status == 404 or e.status == 403:
35+
return
36+
else:
37+
return _kube_api_error_handling(e)
38+
for lq in local_queues["items"]:
39+
if (
40+
"annotations" in lq["metadata"]
41+
and "kueue.x-k8s.io/default-queue" in lq["metadata"]["annotations"]
42+
and lq["metadata"]["annotations"]["kueue.x-k8s.io/default-queue"].lower()
43+
== "true"
44+
):
45+
return lq["metadata"]["name"]
46+
47+
48+
def local_queue_exists(namespace: str, local_queue_name: str):
49+
# get all local queues in the namespace
50+
try:
51+
config_check()
52+
api_instance = client.CustomObjectsApi(get_api_client())
53+
local_queues = api_instance.list_namespaced_custom_object(
54+
group="kueue.x-k8s.io",
55+
version="v1beta1",
56+
namespace=namespace,
57+
plural="localqueues",
58+
)
59+
except Exception as e: # pragma: no cover
60+
return _kube_api_error_handling(e)
61+
# check if local queue with the name provided in cluster config exists
62+
for lq in local_queues["items"]:
63+
if lq["metadata"]["name"] == local_queue_name:
64+
return True
65+
return False
66+
67+
68+
def add_queue_label(item: dict, namespace: str, local_queue: Optional[str]):
69+
lq_name = local_queue or get_default_kueue_name(namespace)
70+
if lq_name == None:
71+
return
72+
elif not local_queue_exists(namespace, lq_name):
73+
raise ValueError(
74+
"local_queue provided does not exist or is not in this namespace. Please provide the correct local_queue name in Cluster Configuration"
75+
)
76+
if not "labels" in item["metadata"]:
77+
item["metadata"]["labels"] = {}
78+
item["metadata"]["labels"].update({"kueue.x-k8s.io/queue-name": lq_name})

Diff for: src/codeflare_sdk/ray/cluster/generate_yaml.py

+1-60
Original file line numberDiff line numberDiff line change
@@ -25,11 +25,11 @@
2525
import uuid
2626
from kubernetes import client
2727
from ...common import _kube_api_error_handling
28+
from codeflare_sdk.common.kueue.kueue import add_queue_label
2829
from ...common.kubernetes_cluster.auth import (
2930
get_api_client,
3031
config_check,
3132
)
32-
from kubernetes.client.exceptions import ApiException
3333
import codeflare_sdk
3434

3535

@@ -229,65 +229,6 @@ def del_from_list_by_name(l: list, target: typing.List[str]) -> list:
229229
return [x for x in l if x["name"] not in target]
230230

231231

232-
def get_default_kueue_name(namespace: str):
233-
# If the local queue is set, use it. Otherwise, try to use the default queue.
234-
try:
235-
config_check()
236-
api_instance = client.CustomObjectsApi(get_api_client())
237-
local_queues = api_instance.list_namespaced_custom_object(
238-
group="kueue.x-k8s.io",
239-
version="v1beta1",
240-
namespace=namespace,
241-
plural="localqueues",
242-
)
243-
except ApiException as e: # pragma: no cover
244-
if e.status == 404 or e.status == 403:
245-
return
246-
else:
247-
return _kube_api_error_handling(e)
248-
for lq in local_queues["items"]:
249-
if (
250-
"annotations" in lq["metadata"]
251-
and "kueue.x-k8s.io/default-queue" in lq["metadata"]["annotations"]
252-
and lq["metadata"]["annotations"]["kueue.x-k8s.io/default-queue"].lower()
253-
== "true"
254-
):
255-
return lq["metadata"]["name"]
256-
257-
258-
def local_queue_exists(namespace: str, local_queue_name: str):
259-
# get all local queues in the namespace
260-
try:
261-
config_check()
262-
api_instance = client.CustomObjectsApi(get_api_client())
263-
local_queues = api_instance.list_namespaced_custom_object(
264-
group="kueue.x-k8s.io",
265-
version="v1beta1",
266-
namespace=namespace,
267-
plural="localqueues",
268-
)
269-
except Exception as e: # pragma: no cover
270-
return _kube_api_error_handling(e)
271-
# check if local queue with the name provided in cluster config exists
272-
for lq in local_queues["items"]:
273-
if lq["metadata"]["name"] == local_queue_name:
274-
return True
275-
return False
276-
277-
278-
def add_queue_label(item: dict, namespace: str, local_queue: Optional[str]):
279-
lq_name = local_queue or get_default_kueue_name(namespace)
280-
if lq_name == None:
281-
return
282-
elif not local_queue_exists(namespace, lq_name):
283-
raise ValueError(
284-
"local_queue provided does not exist or is not in this namespace. Please provide the correct local_queue name in Cluster Configuration"
285-
)
286-
if not "labels" in item["metadata"]:
287-
item["metadata"]["labels"] = {}
288-
item["metadata"]["labels"].update({"kueue.x-k8s.io/queue-name": lq_name})
289-
290-
291232
def augment_labels(item: dict, labels: dict):
292233
if not "labels" in item["metadata"]:
293234
item["metadata"]["labels"] = {}

Diff for: tests/unit_test.py

+11-15
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,8 @@
7373
get_package_and_version,
7474
)
7575

76+
from codeflare_sdk.common.kueue import *
77+
7678
import codeflare_sdk.common.kubernetes_cluster.kube_api_helpers
7779
from codeflare_sdk.ray.cluster.generate_yaml import (
7880
gen_names,
@@ -968,7 +970,7 @@ def test_ray_details(mocker, capsys):
968970
return_value="",
969971
)
970972
mocker.patch(
971-
"codeflare_sdk.ray.cluster.generate_yaml.local_queue_exists",
973+
"codeflare_sdk.common.kueue.kueue.local_queue_exists",
972974
return_value="true",
973975
)
974976
cf = Cluster(
@@ -2007,7 +2009,7 @@ def test_get_cluster_openshift(mocker):
20072009
]
20082010
mocker.patch("kubernetes.client.ApisApi", return_value=mock_api)
20092011
mocker.patch(
2010-
"codeflare_sdk.ray.cluster.generate_yaml.local_queue_exists",
2012+
"codeflare_sdk.common.kueue.kueue.local_queue_exists",
20112013
return_value="true",
20122014
)
20132015

@@ -2042,7 +2044,7 @@ def custom_side_effect(group, version, namespace, plural, **kwargs):
20422044
],
20432045
)
20442046
mocker.patch(
2045-
"codeflare_sdk.ray.cluster.generate_yaml.local_queue_exists",
2047+
"codeflare_sdk.common.kueue.kueue.local_queue_exists",
20462048
return_value="true",
20472049
)
20482050

@@ -2085,7 +2087,7 @@ def test_get_cluster(mocker):
20852087
return_value=ingress_retrieval(cluster_name="quicktest", client_ing=True),
20862088
)
20872089
mocker.patch(
2088-
"codeflare_sdk.ray.cluster.generate_yaml.local_queue_exists",
2090+
"codeflare_sdk.common.kueue.kueue.local_queue_exists",
20892091
return_value="true",
20902092
)
20912093
cluster = get_cluster("quicktest")
@@ -2123,7 +2125,7 @@ def test_get_cluster_no_mcad(mocker):
21232125
return_value=ingress_retrieval(cluster_name="quicktest", client_ing=True),
21242126
)
21252127
mocker.patch(
2126-
"codeflare_sdk.ray.cluster.generate_yaml.local_queue_exists",
2128+
"codeflare_sdk.common.kueue.kueue.local_queue_exists",
21272129
return_value="true",
21282130
)
21292131
cluster = get_cluster("quicktest")
@@ -2359,7 +2361,7 @@ def test_cluster_status(mocker):
23592361
mocker.patch("kubernetes.client.ApisApi.get_api_versions")
23602362
mocker.patch("kubernetes.config.load_kube_config", return_value="ignore")
23612363
mocker.patch(
2362-
"codeflare_sdk.ray.cluster.generate_yaml.local_queue_exists",
2364+
"codeflare_sdk.common.kueue.kueue.local_queue_exists",
23632365
return_value="true",
23642366
)
23652367
fake_aw = AppWrapper("test", AppWrapperStatus.FAILED)
@@ -2456,13 +2458,7 @@ def test_wait_ready(mocker, capsys):
24562458
)
24572459
mocker.patch("kubernetes.config.load_kube_config", return_value="ignore")
24582460
mocker.patch(
2459-
"codeflare_sdk.ray.cluster.cluster._app_wrapper_status", return_value=None
2460-
)
2461-
mocker.patch(
2462-
"codeflare_sdk.ray.cluster.cluster._ray_cluster_status", return_value=None
2463-
)
2464-
mocker.patch(
2465-
"codeflare_sdk.ray.cluster.generate_yaml.local_queue_exists",
2461+
"codeflare_sdk.common.kueue.kueue.local_queue_exists",
24662462
return_value="true",
24672463
)
24682464
mocker.patch.object(
@@ -2694,11 +2690,11 @@ def test_cluster_throw_for_no_raycluster(mocker: MockerFixture):
26942690
return_value="opendatahub",
26952691
)
26962692
mocker.patch(
2697-
"codeflare_sdk.ray.cluster.generate_yaml.get_default_kueue_name",
2693+
"codeflare_sdk.common.kueue.kueue.get_default_kueue_name",
26982694
return_value="default",
26992695
)
27002696
mocker.patch(
2701-
"codeflare_sdk.ray.cluster.generate_yaml.local_queue_exists",
2697+
"codeflare_sdk.common.kueue.kueue.local_queue_exists",
27022698
return_value="true",
27032699
)
27042700

0 commit comments

Comments
 (0)