Skip to content

Commit f694e61

Browse files
committed
Review changes & list_cluster functions
1 parent 5bd5848 commit f694e61

File tree

6 files changed

+400
-53
lines changed

6 files changed

+400
-53
lines changed

Diff for: src/codeflare_sdk/__init__.py

+2
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,8 @@
1111
CodeFlareClusterStatus,
1212
RayCluster,
1313
AppWrapper,
14+
list_all_queued,
15+
list_all_clusters,
1416
)
1517

1618
from .job import JobDefinition, Job, DDPJobDefinition, DDPJob, RayJobClient

Diff for: src/codeflare_sdk/cluster/__init__.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,6 @@
1313
AppWrapper,
1414
)
1515

16-
from .cluster import Cluster, ClusterConfiguration
16+
from .cluster import Cluster, ClusterConfiguration, list_all_queued, list_all_clusters
1717

1818
from .awload import AWManager

Diff for: src/codeflare_sdk/cluster/cluster.py

+27-11
Original file line numberDiff line numberDiff line change
@@ -641,17 +641,24 @@ def list_all_clusters(namespace: str, print_to_console: bool = True):
641641
return clusters
642642

643643

644-
def list_all_queued(namespace: str, print_to_console: bool = True):
644+
def list_all_queued(namespace: str, print_to_console: bool = True, mcad: bool = False):
645645
"""
646-
Returns (and prints by default) a list of all currently queued-up AppWrappers
646+
Returns (and prints by default) a list of all currently queued-up Ray Clusters or AppWrappers
647647
in a given namespace.
648648
"""
649-
app_wrappers = _get_app_wrappers(
650-
namespace, filter=[AppWrapperStatus.RUNNING, AppWrapperStatus.PENDING]
651-
)
652-
if print_to_console:
653-
pretty_print.print_app_wrappers_status(app_wrappers)
654-
return app_wrappers
649+
if mcad:
650+
resources = _get_app_wrappers(
651+
namespace, filter=[AppWrapperStatus.RUNNING, AppWrapperStatus.PENDING]
652+
)
653+
if print_to_console:
654+
pretty_print.print_app_wrappers_status(resources)
655+
else:
656+
resources = _get_ray_clusters(
657+
namespace, filter=[RayClusterStatus.READY, RayClusterStatus.SUSPENDED]
658+
)
659+
if print_to_console:
660+
pretty_print.print_ray_clusters_status(resources)
661+
return resources
655662

656663

657664
def get_current_namespace(): # pragma: no cover
@@ -856,7 +863,9 @@ def _ray_cluster_status(name, namespace="default") -> Optional[RayCluster]:
856863
return None
857864

858865

859-
def _get_ray_clusters(namespace="default") -> List[RayCluster]:
866+
def _get_ray_clusters(
867+
namespace="default", filter: Optional[List[RayClusterStatus]] = None
868+
) -> List[RayCluster]:
860869
list_of_clusters = []
861870
try:
862871
config_check()
@@ -870,8 +879,15 @@ def _get_ray_clusters(namespace="default") -> List[RayCluster]:
870879
except Exception as e: # pragma: no cover
871880
return _kube_api_error_handling(e)
872881

873-
for rc in rcs["items"]:
874-
list_of_clusters.append(_map_to_ray_cluster(rc))
882+
# Get a list of RCs with the filter if it is passed to the function
883+
if filter is not None:
884+
for rc in rcs["items"]:
885+
ray_cluster = _map_to_ray_cluster(rc)
886+
if filter and ray_cluster.status in filter:
887+
list_of_clusters.append(ray_cluster)
888+
else:
889+
for rc in rcs["items"]:
890+
list_of_clusters.append(_map_to_ray_cluster(rc))
875891
return list_of_clusters
876892

877893

Diff for: src/codeflare_sdk/utils/generate_yaml.py

+28-38
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
(in the cluster sub-module) for AppWrapper generation.
1818
"""
1919

20-
import typing
20+
from typing import List, Optional
2121
import yaml
2222
import sys
2323
import os
@@ -467,7 +467,7 @@ def enable_local_interactive(resources, cluster_name, namespace, ingress_domain)
467467
][0].get("command")[2] = command
468468

469469

470-
def del_from_list_by_name(l: list, target: typing.List[str]) -> list:
470+
def del_from_list_by_name(l: list, target: List[str]) -> list:
471471
return [x for x in l if x["name"] not in target]
472472

473473

@@ -622,39 +622,34 @@ def _create_oauth_sidecar_object(
622622
)
623623

624624

625-
def get_default_kueue_name(local_queue: str, namespace: str):
625+
def get_default_kueue_name(namespace: str):
626626
# If the local queue is set, use it. Otherwise, try to use the default queue.
627-
if local_queue is not None:
628-
return local_queue
629-
else:
630-
try:
631-
config_check()
632-
api_instance = client.CustomObjectsApi(api_config_handler())
633-
local_queues = api_instance.list_namespaced_custom_object(
634-
group="kueue.x-k8s.io",
635-
version="v1beta1",
636-
namespace=namespace,
637-
plural="localqueues",
638-
)
639-
except Exception as e: # pragma: no cover
640-
return _kube_api_error_handling(e)
641-
for lq in local_queues["items"]:
642-
if (
643-
"annotations" in lq["metadata"]
644-
and "kueue.x-k8s.io/default-queue" in lq["metadata"]["annotations"]
645-
and lq["metadata"]["annotations"][
646-
"kueue.x-k8s.io/default-queue"
647-
].lower()
648-
== "true"
649-
):
650-
return lq["metadata"]["name"]
651-
raise ValueError(
652-
"Default Local Queue with kueue.x-k8s.io/default-queue: true annotation not found please create a default Local Queue or provide the local_queue name in Cluster Configuration"
627+
try:
628+
config_check()
629+
api_instance = client.CustomObjectsApi(api_config_handler())
630+
local_queues = api_instance.list_namespaced_custom_object(
631+
group="kueue.x-k8s.io",
632+
version="v1beta1",
633+
namespace=namespace,
634+
plural="localqueues",
653635
)
636+
except Exception as e: # pragma: no cover
637+
return _kube_api_error_handling(e)
638+
for lq in local_queues["items"]:
639+
if (
640+
"annotations" in lq["metadata"]
641+
and "kueue.x-k8s.io/default-queue" in lq["metadata"]["annotations"]
642+
and lq["metadata"]["annotations"]["kueue.x-k8s.io/default-queue"].lower()
643+
== "true"
644+
):
645+
return lq["metadata"]["name"]
646+
raise ValueError(
647+
"Default Local Queue with kueue.x-k8s.io/default-queue: true annotation not found please create a default Local Queue or provide the local_queue name in Cluster Configuration"
648+
)
654649

655650

656651
def write_components(
657-
user_yaml: dict, output_file_name: str, namespace: str, local_queue: str
652+
user_yaml: dict, output_file_name: str, namespace: str, local_queue: Optional[str]
658653
):
659654
# Create the directory if it doesn't exist
660655
directory_path = os.path.dirname(output_file_name)
@@ -663,6 +658,7 @@ def write_components(
663658

664659
components = user_yaml.get("spec", "resources")["resources"].get("GenericItems")
665660
open(output_file_name, "w").close()
661+
lq_name = local_queue or get_default_kueue_name(namespace)
666662
with open(output_file_name, "a") as outfile:
667663
for component in components:
668664
if "generictemplate" in component:
@@ -674,13 +670,7 @@ def write_components(
674670
"workload.codeflare.dev/appwrapper"
675671
]
676672
labels = component["generictemplate"]["metadata"]["labels"]
677-
labels.update(
678-
{
679-
"kueue.x-k8s.io/queue-name": get_default_kueue_name(
680-
local_queue, namespace
681-
)
682-
}
683-
)
673+
labels.update({"kueue.x-k8s.io/queue-name": lq_name})
684674
outfile.write("---\n")
685675
yaml.dump(
686676
component["generictemplate"], outfile, default_flow_style=False
@@ -713,7 +703,7 @@ def generate_appwrapper(
713703
openshift_oauth: bool,
714704
ingress_domain: str,
715705
ingress_options: dict,
716-
local_queue: str,
706+
local_queue: Optional[str],
717707
):
718708
user_yaml = read_template(template)
719709
appwrapper_name, cluster_name = gen_names(name)

Diff for: src/codeflare_sdk/utils/pretty_print.py

+24
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,30 @@ def print_app_wrappers_status(app_wrappers: List[AppWrapper], starting: bool = F
5656
console.print(Panel.fit(table))
5757

5858

59+
def print_ray_clusters_status(app_wrappers: List[AppWrapper], starting: bool = False):
60+
if not app_wrappers:
61+
print_no_resources_found()
62+
return # shortcircuit
63+
64+
console = Console()
65+
table = Table(
66+
box=box.ASCII_DOUBLE_HEAD,
67+
title="[bold] :rocket: Cluster Queue Status :rocket:",
68+
)
69+
table.add_column("Name", style="cyan", no_wrap=True)
70+
table.add_column("Status", style="magenta")
71+
72+
for app_wrapper in app_wrappers:
73+
name = app_wrapper.name
74+
status = app_wrapper.status.value
75+
if starting:
76+
status += " (starting)"
77+
table.add_row(name, status)
78+
table.add_row("") # empty row for spacing
79+
80+
console.print(Panel.fit(table))
81+
82+
5983
def print_cluster_status(cluster: RayCluster):
6084
"Pretty prints the status of a passed-in cluster"
6185
if not cluster:

0 commit comments

Comments
 (0)