Skip to content

Use Kueue as default #470

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Apr 5, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions src/codeflare_sdk/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
RayCluster,
AppWrapper,
get_cluster,
list_all_queued,
list_all_clusters,
)

from .job import JobDefinition, Job, DDPJobDefinition, DDPJob, RayJobClient
Expand Down
8 changes: 7 additions & 1 deletion src/codeflare_sdk/cluster/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,12 @@
AppWrapper,
)

from .cluster import Cluster, ClusterConfiguration, get_cluster
from .cluster import (
Cluster,
ClusterConfiguration,
get_cluster,
list_all_queued,
list_all_clusters,
)

from .awload import AWManager
43 changes: 32 additions & 11 deletions src/codeflare_sdk/cluster/cluster.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,6 +189,7 @@ def create_app_wrapper(self):
dispatch_priority = self.config.dispatch_priority
write_to_file = self.config.write_to_file
verify_tls = self.config.verify_tls
local_queue = self.config.local_queue
return generate_appwrapper(
name=name,
namespace=namespace,
Expand All @@ -213,6 +214,7 @@ def create_app_wrapper(self):
priority_val=priority_val,
write_to_file=write_to_file,
verify_tls=verify_tls,
local_queue=local_queue,
)

# creates a new cluster with the provided or default spec
Expand Down Expand Up @@ -319,6 +321,9 @@ def status(
# check the ray cluster status
cluster = _ray_cluster_status(self.config.name, self.config.namespace)
if cluster:
if cluster.status == RayClusterStatus.SUSPENDED:
ready = False
status = CodeFlareClusterStatus.SUSPENDED
if cluster.status == RayClusterStatus.UNKNOWN:
ready = False
status = CodeFlareClusterStatus.STARTING
Expand Down Expand Up @@ -588,17 +593,24 @@ def list_all_clusters(namespace: str, print_to_console: bool = True):
return clusters


def list_all_queued(namespace: str, print_to_console: bool = True):
def list_all_queued(namespace: str, print_to_console: bool = True, mcad: bool = False):
"""
Returns (and prints by default) a list of all currently queued-up AppWrappers
Returns (and prints by default) a list of all currently queued-up Ray Clusters
in a given namespace.
"""
app_wrappers = _get_app_wrappers(
namespace, filter=[AppWrapperStatus.RUNNING, AppWrapperStatus.PENDING]
)
if print_to_console:
pretty_print.print_app_wrappers_status(app_wrappers)
return app_wrappers
if mcad:
resources = _get_app_wrappers(
namespace, filter=[AppWrapperStatus.RUNNING, AppWrapperStatus.PENDING]
)
if print_to_console:
pretty_print.print_app_wrappers_status(resources)
else:
resources = _get_ray_clusters(
namespace, filter=[RayClusterStatus.READY, RayClusterStatus.SUSPENDED]
)
if print_to_console:
pretty_print.print_ray_clusters_status(resources)
return resources


def get_current_namespace(): # pragma: no cover
Expand Down Expand Up @@ -798,7 +810,9 @@ def _ray_cluster_status(name, namespace="default") -> Optional[RayCluster]:
return None


def _get_ray_clusters(namespace="default") -> List[RayCluster]:
def _get_ray_clusters(
namespace="default", filter: Optional[List[RayClusterStatus]] = None
) -> List[RayCluster]:
list_of_clusters = []
try:
config_check()
Expand All @@ -812,8 +826,15 @@ def _get_ray_clusters(namespace="default") -> List[RayCluster]:
except Exception as e: # pragma: no cover
return _kube_api_error_handling(e)

for rc in rcs["items"]:
list_of_clusters.append(_map_to_ray_cluster(rc))
# Get a list of RCs with the filter if it is passed to the function
if filter is not None:
for rc in rcs["items"]:
ray_cluster = _map_to_ray_cluster(rc)
if filter and ray_cluster.status in filter:
list_of_clusters.append(ray_cluster)
else:
for rc in rcs["items"]:
list_of_clusters.append(_map_to_ray_cluster(rc))
return list_of_clusters


Expand Down
4 changes: 3 additions & 1 deletion src/codeflare_sdk/cluster/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ class ClusterConfiguration:
num_gpus: int = 0
template: str = f"{dir}/templates/base-template.yaml"
instascale: bool = False
mcad: bool = True
mcad: bool = False
envs: dict = field(default_factory=dict)
image: str = ""
local_interactive: bool = False
Expand All @@ -60,3 +60,5 @@ def __post_init__(self):
print(
"Warning: TLS verification has been disabled - Endpoint checks will be bypassed"
)

local_queue: str = None
2 changes: 2 additions & 0 deletions src/codeflare_sdk/cluster/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ class RayClusterStatus(Enum):
UNHEALTHY = "unhealthy"
FAILED = "failed"
UNKNOWN = "unknown"
SUSPENDED = "suspended"


class AppWrapperStatus(Enum):
Expand Down Expand Up @@ -59,6 +60,7 @@ class CodeFlareClusterStatus(Enum):
QUEUEING = 4
FAILED = 5
UNKNOWN = 6
SUSPENDED = 7


@dataclass
Expand Down
62 changes: 57 additions & 5 deletions src/codeflare_sdk/utils/generate_yaml.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
(in the cluster sub-module) for AppWrapper generation.
"""

from typing import Optional
import typing
import yaml
import sys
Expand Down Expand Up @@ -460,29 +461,79 @@ def _create_oauth_sidecar_object(
)


def write_components(user_yaml: dict, output_file_name: str):
def get_default_kueue_name(namespace: str):
# If the local queue is set, use it. Otherwise, try to use the default queue.
try:
config_check()
api_instance = client.CustomObjectsApi(api_config_handler())
local_queues = api_instance.list_namespaced_custom_object(
group="kueue.x-k8s.io",
version="v1beta1",
namespace=namespace,
plural="localqueues",
)
except Exception as e: # pragma: no cover
return _kube_api_error_handling(e)
for lq in local_queues["items"]:
if (
"annotations" in lq["metadata"]
and "kueue.x-k8s.io/default-queue" in lq["metadata"]["annotations"]
and lq["metadata"]["annotations"]["kueue.x-k8s.io/default-queue"].lower()
== "true"
):
return lq["metadata"]["name"]
raise ValueError(
"Default Local Queue with kueue.x-k8s.io/default-queue: true annotation not found please create a default Local Queue or provide the local_queue name in Cluster Configuration"
)


def write_components(
user_yaml: dict, output_file_name: str, namespace: str, local_queue: Optional[str]
):
# Create the directory if it doesn't exist
directory_path = os.path.dirname(output_file_name)
if not os.path.exists(directory_path):
os.makedirs(directory_path)

components = user_yaml.get("spec", "resources")["resources"].get("GenericItems")
open(output_file_name, "w").close()
lq_name = local_queue or get_default_kueue_name(namespace)
with open(output_file_name, "a") as outfile:
for component in components:
if "generictemplate" in component:
if (
"workload.codeflare.dev/appwrapper"
in component["generictemplate"]["metadata"]["labels"]
):
del component["generictemplate"]["metadata"]["labels"][
"workload.codeflare.dev/appwrapper"
]
labels = component["generictemplate"]["metadata"]["labels"]
labels.update({"kueue.x-k8s.io/queue-name": lq_name})
outfile.write("---\n")
yaml.dump(
component["generictemplate"], outfile, default_flow_style=False
)
print(f"Written to: {output_file_name}")


def load_components(user_yaml: dict, name: str):
def load_components(
user_yaml: dict, name: str, namespace: str, local_queue: Optional[str]
):
component_list = []
components = user_yaml.get("spec", "resources")["resources"].get("GenericItems")
lq_name = local_queue or get_default_kueue_name(namespace)
for component in components:
if "generictemplate" in component:
if (
"workload.codeflare.dev/appwrapper"
in component["generictemplate"]["metadata"]["labels"]
):
del component["generictemplate"]["metadata"]["labels"][
"workload.codeflare.dev/appwrapper"
]
labels = component["generictemplate"]["metadata"]["labels"]
labels.update({"kueue.x-k8s.io/queue-name": lq_name})
component_list.append(component["generictemplate"])

resources = "---\n" + "---\n".join(
Expand Down Expand Up @@ -523,6 +574,7 @@ def generate_appwrapper(
priority_val: int,
write_to_file: bool,
verify_tls: bool,
local_queue: Optional[str],
):
user_yaml = read_template(template)
appwrapper_name, cluster_name = gen_names(name)
Expand Down Expand Up @@ -575,18 +627,18 @@ def generate_appwrapper(
if is_openshift_cluster():
enable_openshift_oauth(user_yaml, cluster_name, namespace)

directory_path = os.path.expanduser("~/.codeflare/appwrapper/")
directory_path = os.path.expanduser("~/.codeflare/resources/")
outfile = os.path.join(directory_path, appwrapper_name + ".yaml")

if write_to_file:
if mcad:
write_user_appwrapper(user_yaml, outfile)
else:
write_components(user_yaml, outfile)
write_components(user_yaml, outfile, namespace, local_queue)
return outfile
else:
if mcad:
user_yaml = load_appwrapper(user_yaml, name)
else:
user_yaml = load_components(user_yaml, name)
user_yaml = load_components(user_yaml, name, namespace, local_queue)
return user_yaml
24 changes: 24 additions & 0 deletions src/codeflare_sdk/utils/pretty_print.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,30 @@ def print_app_wrappers_status(app_wrappers: List[AppWrapper], starting: bool = F
console.print(Panel.fit(table))


def print_ray_clusters_status(app_wrappers: List[AppWrapper], starting: bool = False):
if not app_wrappers:
print_no_resources_found()
return # shortcircuit

console = Console()
table = Table(
box=box.ASCII_DOUBLE_HEAD,
title="[bold] :rocket: Cluster Queue Status :rocket:",
)
table.add_column("Name", style="cyan", no_wrap=True)
table.add_column("Status", style="magenta")

for app_wrapper in app_wrappers:
name = app_wrapper.name
status = app_wrapper.status.value
if starting:
status += " (starting)"
table.add_row(name, status)
table.add_row("") # empty row for spacing

console.print(Panel.fit(table))


def print_cluster_status(cluster: RayCluster):
"Pretty prints the status of a passed-in cluster"
if not cluster:
Expand Down
1 change: 1 addition & 0 deletions tests/e2e/mnist_raycluster_sdk_oauth_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ def run_mnist_raycluster_sdk_oauth(self):
instascale=False,
image=ray_image,
write_to_file=True,
mcad=True,
)
)

Expand Down
1 change: 1 addition & 0 deletions tests/e2e/mnist_raycluster_sdk_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ def run_mnist_raycluster_sdk(self):
instascale=False,
image=ray_image,
write_to_file=True,
mcad=True,
)
)

Expand Down
1 change: 1 addition & 0 deletions tests/e2e/start_ray_cluster.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
num_gpus=0,
instascale=False,
image=ray_image,
mcad=True,
)
)

Expand Down
2 changes: 1 addition & 1 deletion tests/test-case-no-mcad.yamls
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ metadata:
sdk.codeflare.dev/local_interactive: 'False'
labels:
controller-tools.k8s.io: '1.0'
workload.codeflare.dev/appwrapper: unit-test-cluster-ray
kueue.x-k8s.io/queue-name: local-queue-default
name: unit-test-cluster-ray
namespace: ns
spec:
Expand Down
Loading
Loading