Skip to content

Commit f4639dd

Browse files
committed
Made Kueue the default queueing strategy
Updated oauth test to have mcad=True Changed .codeflare/appwrappers to .codeflare/resources Addressed comments & added SUSPENDED status Review changes & list_cluster functions Updated tests and load_components Update tests, Rebase
1 parent bd49ef7 commit f4639dd

13 files changed

+515
-23
lines changed

Diff for: src/codeflare_sdk/__init__.py

+2
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,8 @@
1212
RayCluster,
1313
AppWrapper,
1414
get_cluster,
15+
list_all_queued,
16+
list_all_clusters,
1517
)
1618

1719
from .job import JobDefinition, Job, DDPJobDefinition, DDPJob, RayJobClient

Diff for: src/codeflare_sdk/cluster/__init__.py

+7-1
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,12 @@
1313
AppWrapper,
1414
)
1515

16-
from .cluster import Cluster, ClusterConfiguration, get_cluster
16+
from .cluster import (
17+
Cluster,
18+
ClusterConfiguration,
19+
get_cluster,
20+
list_all_queued,
21+
list_all_clusters,
22+
)
1723

1824
from .awload import AWManager

Diff for: src/codeflare_sdk/cluster/cluster.py

+32-11
Original file line numberDiff line numberDiff line change
@@ -189,6 +189,7 @@ def create_app_wrapper(self):
189189
dispatch_priority = self.config.dispatch_priority
190190
write_to_file = self.config.write_to_file
191191
verify_tls = self.config.verify_tls
192+
local_queue = self.config.local_queue
192193
return generate_appwrapper(
193194
name=name,
194195
namespace=namespace,
@@ -213,6 +214,7 @@ def create_app_wrapper(self):
213214
priority_val=priority_val,
214215
write_to_file=write_to_file,
215216
verify_tls=verify_tls,
217+
local_queue=local_queue,
216218
)
217219

218220
# creates a new cluster with the provided or default spec
@@ -319,6 +321,9 @@ def status(
319321
# check the ray cluster status
320322
cluster = _ray_cluster_status(self.config.name, self.config.namespace)
321323
if cluster:
324+
if cluster.status == RayClusterStatus.SUSPENDED:
325+
ready = False
326+
status = CodeFlareClusterStatus.SUSPENDED
322327
if cluster.status == RayClusterStatus.UNKNOWN:
323328
ready = False
324329
status = CodeFlareClusterStatus.STARTING
@@ -588,17 +593,24 @@ def list_all_clusters(namespace: str, print_to_console: bool = True):
588593
return clusters
589594

590595

591-
def list_all_queued(namespace: str, print_to_console: bool = True):
596+
def list_all_queued(namespace: str, print_to_console: bool = True, mcad: bool = False):
592597
"""
593-
Returns (and prints by default) a list of all currently queued-up AppWrappers
598+
Returns (and prints by default) a list of all currently queued-up Ray Clusters or AppWrappers
594599
in a given namespace.
595600
"""
596-
app_wrappers = _get_app_wrappers(
597-
namespace, filter=[AppWrapperStatus.RUNNING, AppWrapperStatus.PENDING]
598-
)
599-
if print_to_console:
600-
pretty_print.print_app_wrappers_status(app_wrappers)
601-
return app_wrappers
601+
if mcad:
602+
resources = _get_app_wrappers(
603+
namespace, filter=[AppWrapperStatus.RUNNING, AppWrapperStatus.PENDING]
604+
)
605+
if print_to_console:
606+
pretty_print.print_app_wrappers_status(resources)
607+
else:
608+
resources = _get_ray_clusters(
609+
namespace, filter=[RayClusterStatus.READY, RayClusterStatus.SUSPENDED]
610+
)
611+
if print_to_console:
612+
pretty_print.print_ray_clusters_status(resources)
613+
return resources
602614

603615

604616
def get_current_namespace(): # pragma: no cover
@@ -798,7 +810,9 @@ def _ray_cluster_status(name, namespace="default") -> Optional[RayCluster]:
798810
return None
799811

800812

801-
def _get_ray_clusters(namespace="default") -> List[RayCluster]:
813+
def _get_ray_clusters(
814+
namespace="default", filter: Optional[List[RayClusterStatus]] = None
815+
) -> List[RayCluster]:
802816
list_of_clusters = []
803817
try:
804818
config_check()
@@ -812,8 +826,15 @@ def _get_ray_clusters(namespace="default") -> List[RayCluster]:
812826
except Exception as e: # pragma: no cover
813827
return _kube_api_error_handling(e)
814828

815-
for rc in rcs["items"]:
816-
list_of_clusters.append(_map_to_ray_cluster(rc))
829+
# Get a list of RCs with the filter if it is passed to the function
830+
if filter is not None:
831+
for rc in rcs["items"]:
832+
ray_cluster = _map_to_ray_cluster(rc)
833+
if filter and ray_cluster.status in filter:
834+
list_of_clusters.append(ray_cluster)
835+
else:
836+
for rc in rcs["items"]:
837+
list_of_clusters.append(_map_to_ray_cluster(rc))
817838
return list_of_clusters
818839

819840

Diff for: src/codeflare_sdk/cluster/config.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ class ClusterConfiguration:
4646
num_gpus: int = 0
4747
template: str = f"{dir}/templates/base-template.yaml"
4848
instascale: bool = False
49-
mcad: bool = True
49+
mcad: bool = False
5050
envs: dict = field(default_factory=dict)
5151
image: str = ""
5252
local_interactive: bool = False
@@ -60,3 +60,5 @@ def __post_init__(self):
6060
print(
6161
"Warning: TLS verification has been disabled - Endpoint checks will be bypassed"
6262
)
63+
64+
local_queue: str = None

Diff for: src/codeflare_sdk/cluster/model.py

+2
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ class RayClusterStatus(Enum):
3232
UNHEALTHY = "unhealthy"
3333
FAILED = "failed"
3434
UNKNOWN = "unknown"
35+
SUSPENDED = "suspended"
3536

3637

3738
class AppWrapperStatus(Enum):
@@ -59,6 +60,7 @@ class CodeFlareClusterStatus(Enum):
5960
QUEUEING = 4
6061
FAILED = 5
6162
UNKNOWN = 6
63+
SUSPENDED = 7
6264

6365

6466
@dataclass

Diff for: src/codeflare_sdk/utils/generate_yaml.py

+57-5
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
(in the cluster sub-module) for AppWrapper generation.
1818
"""
1919

20+
from typing import Optional
2021
import typing
2122
import yaml
2223
import sys
@@ -460,29 +461,79 @@ def _create_oauth_sidecar_object(
460461
)
461462

462463

463-
def write_components(user_yaml: dict, output_file_name: str):
464+
def get_default_kueue_name(namespace: str):
465+
# If the local queue is set, use it. Otherwise, try to use the default queue.
466+
try:
467+
config_check()
468+
api_instance = client.CustomObjectsApi(api_config_handler())
469+
local_queues = api_instance.list_namespaced_custom_object(
470+
group="kueue.x-k8s.io",
471+
version="v1beta1",
472+
namespace=namespace,
473+
plural="localqueues",
474+
)
475+
except Exception as e: # pragma: no cover
476+
return _kube_api_error_handling(e)
477+
for lq in local_queues["items"]:
478+
if (
479+
"annotations" in lq["metadata"]
480+
and "kueue.x-k8s.io/default-queue" in lq["metadata"]["annotations"]
481+
and lq["metadata"]["annotations"]["kueue.x-k8s.io/default-queue"].lower()
482+
== "true"
483+
):
484+
return lq["metadata"]["name"]
485+
raise ValueError(
486+
"Default Local Queue with kueue.x-k8s.io/default-queue: true annotation not found please create a default Local Queue or provide the local_queue name in Cluster Configuration"
487+
)
488+
489+
490+
def write_components(
491+
user_yaml: dict, output_file_name: str, namespace: str, local_queue: Optional[str]
492+
):
464493
# Create the directory if it doesn't exist
465494
directory_path = os.path.dirname(output_file_name)
466495
if not os.path.exists(directory_path):
467496
os.makedirs(directory_path)
468497

469498
components = user_yaml.get("spec", "resources")["resources"].get("GenericItems")
470499
open(output_file_name, "w").close()
500+
lq_name = local_queue or get_default_kueue_name(namespace)
471501
with open(output_file_name, "a") as outfile:
472502
for component in components:
473503
if "generictemplate" in component:
504+
if (
505+
"workload.codeflare.dev/appwrapper"
506+
in component["generictemplate"]["metadata"]["labels"]
507+
):
508+
del component["generictemplate"]["metadata"]["labels"][
509+
"workload.codeflare.dev/appwrapper"
510+
]
511+
labels = component["generictemplate"]["metadata"]["labels"]
512+
labels.update({"kueue.x-k8s.io/queue-name": lq_name})
474513
outfile.write("---\n")
475514
yaml.dump(
476515
component["generictemplate"], outfile, default_flow_style=False
477516
)
478517
print(f"Written to: {output_file_name}")
479518

480519

481-
def load_components(user_yaml: dict, name: str):
520+
def load_components(
521+
user_yaml: dict, name: str, namespace: str, local_queue: Optional[str]
522+
):
482523
component_list = []
483524
components = user_yaml.get("spec", "resources")["resources"].get("GenericItems")
525+
lq_name = local_queue or get_default_kueue_name(namespace)
484526
for component in components:
485527
if "generictemplate" in component:
528+
if (
529+
"workload.codeflare.dev/appwrapper"
530+
in component["generictemplate"]["metadata"]["labels"]
531+
):
532+
del component["generictemplate"]["metadata"]["labels"][
533+
"workload.codeflare.dev/appwrapper"
534+
]
535+
labels = component["generictemplate"]["metadata"]["labels"]
536+
labels.update({"kueue.x-k8s.io/queue-name": lq_name})
486537
component_list.append(component["generictemplate"])
487538

488539
resources = "---\n" + "---\n".join(
@@ -523,6 +574,7 @@ def generate_appwrapper(
523574
priority_val: int,
524575
write_to_file: bool,
525576
verify_tls: bool,
577+
local_queue: Optional[str],
526578
):
527579
user_yaml = read_template(template)
528580
appwrapper_name, cluster_name = gen_names(name)
@@ -575,18 +627,18 @@ def generate_appwrapper(
575627
if is_openshift_cluster():
576628
enable_openshift_oauth(user_yaml, cluster_name, namespace)
577629

578-
directory_path = os.path.expanduser("~/.codeflare/appwrapper/")
630+
directory_path = os.path.expanduser("~/.codeflare/resources/")
579631
outfile = os.path.join(directory_path, appwrapper_name + ".yaml")
580632

581633
if write_to_file:
582634
if mcad:
583635
write_user_appwrapper(user_yaml, outfile)
584636
else:
585-
write_components(user_yaml, outfile)
637+
write_components(user_yaml, outfile, namespace, local_queue)
586638
return outfile
587639
else:
588640
if mcad:
589641
user_yaml = load_appwrapper(user_yaml, name)
590642
else:
591-
user_yaml = load_components(user_yaml, name)
643+
user_yaml = load_components(user_yaml, name, namespace, local_queue)
592644
return user_yaml

Diff for: src/codeflare_sdk/utils/pretty_print.py

+24
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,30 @@ def print_app_wrappers_status(app_wrappers: List[AppWrapper], starting: bool = F
5656
console.print(Panel.fit(table))
5757

5858

59+
def print_ray_clusters_status(app_wrappers: List[AppWrapper], starting: bool = False):
60+
if not app_wrappers:
61+
print_no_resources_found()
62+
return # shortcircuit
63+
64+
console = Console()
65+
table = Table(
66+
box=box.ASCII_DOUBLE_HEAD,
67+
title="[bold] :rocket: Cluster Queue Status :rocket:",
68+
)
69+
table.add_column("Name", style="cyan", no_wrap=True)
70+
table.add_column("Status", style="magenta")
71+
72+
for app_wrapper in app_wrappers:
73+
name = app_wrapper.name
74+
status = app_wrapper.status.value
75+
if starting:
76+
status += " (starting)"
77+
table.add_row(name, status)
78+
table.add_row("") # empty row for spacing
79+
80+
console.print(Panel.fit(table))
81+
82+
5983
def print_cluster_status(cluster: RayCluster):
6084
"Pretty prints the status of a passed-in cluster"
6185
if not cluster:

Diff for: tests/e2e/mnist_raycluster_sdk_oauth_test.py

+1
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,7 @@ def run_mnist_raycluster_sdk_oauth(self):
5252
instascale=False,
5353
image=ray_image,
5454
write_to_file=True,
55+
mcad=True,
5556
)
5657
)
5758

Diff for: tests/e2e/mnist_raycluster_sdk_test.py

+1
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,7 @@ def run_mnist_raycluster_sdk(self):
5252
instascale=False,
5353
image=ray_image,
5454
write_to_file=True,
55+
mcad=True,
5556
)
5657
)
5758

Diff for: tests/e2e/start_ray_cluster.py

+1
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
num_gpus=0,
2323
instascale=False,
2424
image=ray_image,
25+
mcad=True,
2526
)
2627
)
2728

Diff for: tests/test-case-no-mcad.yamls

+1-1
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ metadata:
66
sdk.codeflare.dev/local_interactive: 'False'
77
labels:
88
controller-tools.k8s.io: '1.0'
9-
workload.codeflare.dev/appwrapper: unit-test-cluster-ray
9+
kueue.x-k8s.io/queue-name: local-queue-default
1010
name: unit-test-cluster-ray
1111
namespace: ns
1212
spec:

0 commit comments

Comments
 (0)