Skip to content
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.

Commit bc762a0

Browse files
committedApr 4, 2024·
Made Kueue the default queueing strategy
Updated oauth test to have mcad=True Changed .codeflare/appwrappers to .codeflare/resources Addressed comments & added SUSPENDED status Review changes & list_cluster functions Updated tests and load_components Update tests, Rebase
1 parent 403cca6 commit bc762a0

13 files changed

+515
-23
lines changed
 

‎src/codeflare_sdk/__init__.py

+2
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,8 @@
1212
RayCluster,
1313
AppWrapper,
1414
get_cluster,
15+
list_all_queued,
16+
list_all_clusters,
1517
)
1618

1719
from .job import JobDefinition, Job, DDPJobDefinition, DDPJob, RayJobClient

‎src/codeflare_sdk/cluster/__init__.py

+7-1
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,12 @@
1313
AppWrapper,
1414
)
1515

16-
from .cluster import Cluster, ClusterConfiguration, get_cluster
16+
from .cluster import (
17+
Cluster,
18+
ClusterConfiguration,
19+
get_cluster,
20+
list_all_queued,
21+
list_all_clusters,
22+
)
1723

1824
from .awload import AWManager

‎src/codeflare_sdk/cluster/cluster.py

+32-11
Original file line numberDiff line numberDiff line change
@@ -191,6 +191,7 @@ def create_app_wrapper(self):
191191
ingress_options = self.config.ingress_options
192192
write_to_file = self.config.write_to_file
193193
verify_tls = self.config.verify_tls
194+
local_queue = self.config.local_queue
194195
return generate_appwrapper(
195196
name=name,
196197
namespace=namespace,
@@ -217,6 +218,7 @@ def create_app_wrapper(self):
217218
ingress_options=ingress_options,
218219
write_to_file=write_to_file,
219220
verify_tls=verify_tls,
221+
local_queue=local_queue,
220222
)
221223

222224
# creates a new cluster with the provided or default spec
@@ -323,6 +325,9 @@ def status(
323325
# check the ray cluster status
324326
cluster = _ray_cluster_status(self.config.name, self.config.namespace)
325327
if cluster:
328+
if cluster.status == RayClusterStatus.SUSPENDED:
329+
ready = False
330+
status = CodeFlareClusterStatus.SUSPENDED
326331
if cluster.status == RayClusterStatus.UNKNOWN:
327332
ready = False
328333
status = CodeFlareClusterStatus.STARTING
@@ -601,17 +606,24 @@ def list_all_clusters(namespace: str, print_to_console: bool = True):
601606
return clusters
602607

603608

604-
def list_all_queued(namespace: str, print_to_console: bool = True):
609+
def list_all_queued(namespace: str, print_to_console: bool = True, mcad: bool = False):
605610
"""
606-
Returns (and prints by default) a list of all currently queued-up AppWrappers
611+
Returns (and prints by default) a list of all currently queued-up Ray Clusters or AppWrappers
607612
in a given namespace.
608613
"""
609-
app_wrappers = _get_app_wrappers(
610-
namespace, filter=[AppWrapperStatus.RUNNING, AppWrapperStatus.PENDING]
611-
)
612-
if print_to_console:
613-
pretty_print.print_app_wrappers_status(app_wrappers)
614-
return app_wrappers
614+
if mcad:
615+
resources = _get_app_wrappers(
616+
namespace, filter=[AppWrapperStatus.RUNNING, AppWrapperStatus.PENDING]
617+
)
618+
if print_to_console:
619+
pretty_print.print_app_wrappers_status(resources)
620+
else:
621+
resources = _get_ray_clusters(
622+
namespace, filter=[RayClusterStatus.READY, RayClusterStatus.SUSPENDED]
623+
)
624+
if print_to_console:
625+
pretty_print.print_ray_clusters_status(resources)
626+
return resources
615627

616628

617629
def get_current_namespace(): # pragma: no cover
@@ -898,7 +910,9 @@ def _ray_cluster_status(name, namespace="default") -> Optional[RayCluster]:
898910
return None
899911

900912

901-
def _get_ray_clusters(namespace="default") -> List[RayCluster]:
913+
def _get_ray_clusters(
914+
namespace="default", filter: Optional[List[RayClusterStatus]] = None
915+
) -> List[RayCluster]:
902916
list_of_clusters = []
903917
try:
904918
config_check()
@@ -912,8 +926,15 @@ def _get_ray_clusters(namespace="default") -> List[RayCluster]:
912926
except Exception as e: # pragma: no cover
913927
return _kube_api_error_handling(e)
914928

915-
for rc in rcs["items"]:
916-
list_of_clusters.append(_map_to_ray_cluster(rc))
929+
# Get a list of RCs with the filter if it is passed to the function
930+
if filter is not None:
931+
for rc in rcs["items"]:
932+
ray_cluster = _map_to_ray_cluster(rc)
933+
if filter and ray_cluster.status in filter:
934+
list_of_clusters.append(ray_cluster)
935+
else:
936+
for rc in rcs["items"]:
937+
list_of_clusters.append(_map_to_ray_cluster(rc))
917938
return list_of_clusters
918939

919940

‎src/codeflare_sdk/cluster/config.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ class ClusterConfiguration:
4646
num_gpus: int = 0
4747
template: str = f"{dir}/templates/base-template.yaml"
4848
instascale: bool = False
49-
mcad: bool = True
49+
mcad: bool = False
5050
envs: dict = field(default_factory=dict)
5151
image: str = ""
5252
local_interactive: bool = False
@@ -62,3 +62,5 @@ def __post_init__(self):
6262
print(
6363
"Warning: TLS verification has been disabled - Endpoint checks will be bypassed"
6464
)
65+
66+
local_queue: str = None

‎src/codeflare_sdk/cluster/model.py

+2
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ class RayClusterStatus(Enum):
3232
UNHEALTHY = "unhealthy"
3333
FAILED = "failed"
3434
UNKNOWN = "unknown"
35+
SUSPENDED = "suspended"
3536

3637

3738
class AppWrapperStatus(Enum):
@@ -59,6 +60,7 @@ class CodeFlareClusterStatus(Enum):
5960
QUEUEING = 4
6061
FAILED = 5
6162
UNKNOWN = 6
63+
SUSPENDED = 7
6264

6365

6466
@dataclass

‎src/codeflare_sdk/utils/generate_yaml.py

+57-5
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
(in the cluster sub-module) for AppWrapper generation.
1818
"""
1919

20+
from typing import Optional
2021
import typing
2122
import yaml
2223
import sys
@@ -646,29 +647,79 @@ def _create_oauth_sidecar_object(
646647
)
647648

648649

649-
def write_components(user_yaml: dict, output_file_name: str):
650+
def get_default_kueue_name(namespace: str):
651+
# If the local queue is set, use it. Otherwise, try to use the default queue.
652+
try:
653+
config_check()
654+
api_instance = client.CustomObjectsApi(api_config_handler())
655+
local_queues = api_instance.list_namespaced_custom_object(
656+
group="kueue.x-k8s.io",
657+
version="v1beta1",
658+
namespace=namespace,
659+
plural="localqueues",
660+
)
661+
except Exception as e: # pragma: no cover
662+
return _kube_api_error_handling(e)
663+
for lq in local_queues["items"]:
664+
if (
665+
"annotations" in lq["metadata"]
666+
and "kueue.x-k8s.io/default-queue" in lq["metadata"]["annotations"]
667+
and lq["metadata"]["annotations"]["kueue.x-k8s.io/default-queue"].lower()
668+
== "true"
669+
):
670+
return lq["metadata"]["name"]
671+
raise ValueError(
672+
"Default Local Queue with kueue.x-k8s.io/default-queue: true annotation not found please create a default Local Queue or provide the local_queue name in Cluster Configuration"
673+
)
674+
675+
676+
def write_components(
677+
user_yaml: dict, output_file_name: str, namespace: str, local_queue: Optional[str]
678+
):
650679
# Create the directory if it doesn't exist
651680
directory_path = os.path.dirname(output_file_name)
652681
if not os.path.exists(directory_path):
653682
os.makedirs(directory_path)
654683

655684
components = user_yaml.get("spec", "resources")["resources"].get("GenericItems")
656685
open(output_file_name, "w").close()
686+
lq_name = local_queue or get_default_kueue_name(namespace)
657687
with open(output_file_name, "a") as outfile:
658688
for component in components:
659689
if "generictemplate" in component:
690+
if (
691+
"workload.codeflare.dev/appwrapper"
692+
in component["generictemplate"]["metadata"]["labels"]
693+
):
694+
del component["generictemplate"]["metadata"]["labels"][
695+
"workload.codeflare.dev/appwrapper"
696+
]
697+
labels = component["generictemplate"]["metadata"]["labels"]
698+
labels.update({"kueue.x-k8s.io/queue-name": lq_name})
660699
outfile.write("---\n")
661700
yaml.dump(
662701
component["generictemplate"], outfile, default_flow_style=False
663702
)
664703
print(f"Written to: {output_file_name}")
665704

666705

667-
def load_components(user_yaml: dict, name: str):
706+
def load_components(
707+
user_yaml: dict, name: str, namespace: str, local_queue: Optional[str]
708+
):
668709
component_list = []
669710
components = user_yaml.get("spec", "resources")["resources"].get("GenericItems")
711+
lq_name = local_queue or get_default_kueue_name(namespace)
670712
for component in components:
671713
if "generictemplate" in component:
714+
if (
715+
"workload.codeflare.dev/appwrapper"
716+
in component["generictemplate"]["metadata"]["labels"]
717+
):
718+
del component["generictemplate"]["metadata"]["labels"][
719+
"workload.codeflare.dev/appwrapper"
720+
]
721+
labels = component["generictemplate"]["metadata"]["labels"]
722+
labels.update({"kueue.x-k8s.io/queue-name": lq_name})
672723
component_list.append(component["generictemplate"])
673724

674725
resources = "---\n" + "---\n".join(
@@ -711,6 +762,7 @@ def generate_appwrapper(
711762
ingress_options: dict,
712763
write_to_file: bool,
713764
verify_tls: bool,
765+
local_queue: Optional[str],
714766
):
715767
user_yaml = read_template(template)
716768
appwrapper_name, cluster_name = gen_names(name)
@@ -771,18 +823,18 @@ def generate_appwrapper(
771823
if is_openshift_cluster():
772824
enable_openshift_oauth(user_yaml, cluster_name, namespace)
773825

774-
directory_path = os.path.expanduser("~/.codeflare/appwrapper/")
826+
directory_path = os.path.expanduser("~/.codeflare/resources/")
775827
outfile = os.path.join(directory_path, appwrapper_name + ".yaml")
776828

777829
if write_to_file:
778830
if mcad:
779831
write_user_appwrapper(user_yaml, outfile)
780832
else:
781-
write_components(user_yaml, outfile)
833+
write_components(user_yaml, outfile, namespace, local_queue)
782834
return outfile
783835
else:
784836
if mcad:
785837
user_yaml = load_appwrapper(user_yaml, name)
786838
else:
787-
user_yaml = load_components(user_yaml, name)
839+
user_yaml = load_components(user_yaml, name, namespace, local_queue)
788840
return user_yaml

‎src/codeflare_sdk/utils/pretty_print.py

+24
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,30 @@ def print_app_wrappers_status(app_wrappers: List[AppWrapper], starting: bool = F
5656
console.print(Panel.fit(table))
5757

5858

59+
def print_ray_clusters_status(app_wrappers: List[AppWrapper], starting: bool = False):
60+
if not app_wrappers:
61+
print_no_resources_found()
62+
return # shortcircuit
63+
64+
console = Console()
65+
table = Table(
66+
box=box.ASCII_DOUBLE_HEAD,
67+
title="[bold] :rocket: Cluster Queue Status :rocket:",
68+
)
69+
table.add_column("Name", style="cyan", no_wrap=True)
70+
table.add_column("Status", style="magenta")
71+
72+
for app_wrapper in app_wrappers:
73+
name = app_wrapper.name
74+
status = app_wrapper.status.value
75+
if starting:
76+
status += " (starting)"
77+
table.add_row(name, status)
78+
table.add_row("") # empty row for spacing
79+
80+
console.print(Panel.fit(table))
81+
82+
5983
def print_cluster_status(cluster: RayCluster):
6084
"Pretty prints the status of a passed-in cluster"
6185
if not cluster:

‎tests/e2e/mnist_raycluster_sdk_oauth_test.py

+1
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,7 @@ def run_mnist_raycluster_sdk_oauth(self):
5252
instascale=False,
5353
image=ray_image,
5454
write_to_file=True,
55+
mcad=True,
5556
)
5657
)
5758

‎tests/e2e/mnist_raycluster_sdk_test.py

+1
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,7 @@ def run_mnist_raycluster_sdk(self):
7171
image=ray_image,
7272
ingress_options=ingress_options,
7373
write_to_file=True,
74+
mcad=True,
7475
)
7576
)
7677

‎tests/e2e/start_ray_cluster.py

+1
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@
3838
instascale=False,
3939
image=ray_image,
4040
ingress_options=ingress_options,
41+
mcad=True,
4142
)
4243
)
4344

‎tests/test-case-no-mcad.yamls

+1-1
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ metadata:
77
sdk.codeflare.dev/local_interactive: 'False'
88
labels:
99
controller-tools.k8s.io: '1.0'
10-
workload.codeflare.dev/appwrapper: unit-test-cluster-ray
10+
kueue.x-k8s.io/queue-name: local-queue-default
1111
name: unit-test-cluster-ray
1212
namespace: ns
1313
spec:

0 commit comments

Comments
 (0)
Please sign in to comment.