diff --git a/src/codeflare_sdk/common/utils/unit_test_support.py b/src/codeflare_sdk/common/utils/unit_test_support.py index 82f301a24..373283b82 100644 --- a/src/codeflare_sdk/common/utils/unit_test_support.py +++ b/src/codeflare_sdk/common/utils/unit_test_support.py @@ -22,6 +22,7 @@ import yaml from pathlib import Path from kubernetes import client +from kubernetes.client import V1Toleration from unittest.mock import patch parent = Path(__file__).resolve().parents[4] # project directory @@ -427,8 +428,18 @@ def create_cluster_all_config_params(mocker, cluster_name, is_appwrapper) -> Clu head_memory_requests=12, head_memory_limits=16, head_extended_resource_requests={"nvidia.com/gpu": 1, "intel.com/gpu": 2}, + head_tolerations=[ + V1Toleration( + key="key1", operator="Equal", value="value1", effect="NoSchedule" + ) + ], worker_cpu_requests=4, worker_cpu_limits=8, + worker_tolerations=[ + V1Toleration( + key="key2", operator="Equal", value="value2", effect="NoSchedule" + ) + ], num_workers=10, worker_memory_requests=12, worker_memory_limits=16, diff --git a/src/codeflare_sdk/ray/cluster/build_ray_cluster.py b/src/codeflare_sdk/ray/cluster/build_ray_cluster.py index a08f3f732..215ac32e2 100644 --- a/src/codeflare_sdk/ray/cluster/build_ray_cluster.py +++ b/src/codeflare_sdk/ray/cluster/build_ray_cluster.py @@ -16,7 +16,7 @@ This sub-module exists primarily to be used internally by the Cluster object (in the cluster sub-module) for RayCluster/AppWrapper generation. """ -from typing import Union, Tuple, Dict +from typing import List, Union, Tuple, Dict from ...common import _kube_api_error_handling from ...common.kubernetes_cluster import get_api_client, config_check from kubernetes.client.exceptions import ApiException @@ -40,6 +40,7 @@ V1PodTemplateSpec, V1PodSpec, V1LocalObjectReference, + V1Toleration, ) import yaml @@ -139,7 +140,11 @@ def build_ray_cluster(cluster: "codeflare_sdk.ray.cluster.Cluster"): "resources": head_resources, }, "template": { - "spec": get_pod_spec(cluster, [get_head_container_spec(cluster)]) + "spec": get_pod_spec( + cluster, + [get_head_container_spec(cluster)], + cluster.config.head_tolerations, + ) }, }, "workerGroupSpecs": [ @@ -154,7 +159,11 @@ def build_ray_cluster(cluster: "codeflare_sdk.ray.cluster.Cluster"): "resources": worker_resources, }, "template": V1PodTemplateSpec( - spec=get_pod_spec(cluster, [get_worker_container_spec(cluster)]) + spec=get_pod_spec( + cluster, + [get_worker_container_spec(cluster)], + cluster.config.worker_tolerations, + ) ), } ], @@ -243,14 +252,21 @@ def update_image(image) -> str: return image -def get_pod_spec(cluster: "codeflare_sdk.ray.cluster.Cluster", containers): +def get_pod_spec( + cluster: "codeflare_sdk.ray.cluster.Cluster", + containers: List, + tolerations: List[V1Toleration], +) -> V1PodSpec: """ The get_pod_spec() function generates a V1PodSpec for the head/worker containers """ + pod_spec = V1PodSpec( containers=containers, volumes=generate_custom_storage(cluster.config.volumes, VOLUMES), + tolerations=tolerations or None, ) + if cluster.config.image_pull_secrets != []: pod_spec.image_pull_secrets = generate_image_pull_secrets(cluster) diff --git a/src/codeflare_sdk/ray/cluster/config.py b/src/codeflare_sdk/ray/cluster/config.py index 7a78e7303..ab64be839 100644 --- a/src/codeflare_sdk/ray/cluster/config.py +++ b/src/codeflare_sdk/ray/cluster/config.py @@ -22,7 +22,7 @@ import warnings from dataclasses import dataclass, field, fields from typing import Dict, List, Optional, Union, get_args, get_origin -from kubernetes.client import V1Volume, V1VolumeMount +from kubernetes.client import V1Toleration, V1Volume, V1VolumeMount dir = pathlib.Path(__file__).parent.parent.resolve() @@ -58,6 +58,8 @@ class ClusterConfiguration: The number of GPUs to allocate to the head node. (Deprecated, use head_extended_resource_requests) head_extended_resource_requests: A dictionary of extended resource requests for the head node. ex: {"nvidia.com/gpu": 1} + head_tolerations: + List of tolerations for head nodes. min_cpus: The minimum number of CPUs to allocate to each worker. max_cpus: @@ -70,6 +72,8 @@ class ClusterConfiguration: The maximum amount of memory to allocate to each worker. num_gpus: The number of GPUs to allocate to each worker. (Deprecated, use worker_extended_resource_requests) + worker_tolerations: + List of tolerations for worker nodes. appwrapper: A boolean indicating whether to use an AppWrapper. envs: @@ -110,6 +114,7 @@ class ClusterConfiguration: head_extended_resource_requests: Dict[str, Union[str, int]] = field( default_factory=dict ) + head_tolerations: Optional[List[V1Toleration]] = None worker_cpu_requests: Union[int, str] = 1 worker_cpu_limits: Union[int, str] = 1 min_cpus: Optional[Union[int, str]] = None # Deprecating @@ -120,6 +125,7 @@ class ClusterConfiguration: min_memory: Optional[Union[int, str]] = None # Deprecating max_memory: Optional[Union[int, str]] = None # Deprecating num_gpus: Optional[int] = None # Deprecating + worker_tolerations: Optional[List[V1Toleration]] = None appwrapper: bool = False envs: Dict[str, str] = field(default_factory=dict) image: str = "" @@ -272,7 +278,10 @@ def check_type(value, expected_type): if origin_type is Union: return any(check_type(value, union_type) for union_type in args) if origin_type is list: - return all(check_type(elem, args[0]) for elem in value) + if value is not None: + return all(check_type(elem, args[0]) for elem in (value or [])) + else: + return True if origin_type is dict: return all( check_type(k, args[0]) and check_type(v, args[1]) diff --git a/tests/test_cluster_yamls/appwrapper/unit-test-all-params.yaml b/tests/test_cluster_yamls/appwrapper/unit-test-all-params.yaml index e0ecc75d3..0977d659d 100644 --- a/tests/test_cluster_yamls/appwrapper/unit-test-all-params.yaml +++ b/tests/test_cluster_yamls/appwrapper/unit-test-all-params.yaml @@ -99,6 +99,11 @@ spec: imagePullSecrets: - name: secret1 - name: secret2 + tolerations: + - effect: NoSchedule + key: key1 + operator: Equal + value: value1 volumes: - emptyDir: sizeLimit: 500Gi @@ -185,6 +190,11 @@ spec: imagePullSecrets: - name: secret1 - name: secret2 + tolerations: + - effect: NoSchedule + key: key2 + operator: Equal + value: value2 volumes: - emptyDir: sizeLimit: 500Gi diff --git a/tests/test_cluster_yamls/ray/unit-test-all-params.yaml b/tests/test_cluster_yamls/ray/unit-test-all-params.yaml index e743e9fe0..188319ab1 100644 --- a/tests/test_cluster_yamls/ray/unit-test-all-params.yaml +++ b/tests/test_cluster_yamls/ray/unit-test-all-params.yaml @@ -90,6 +90,11 @@ spec: imagePullSecrets: - name: secret1 - name: secret2 + tolerations: + - effect: NoSchedule + key: key1 + operator: Equal + value: value1 volumes: - emptyDir: sizeLimit: 500Gi @@ -176,6 +181,11 @@ spec: imagePullSecrets: - name: secret1 - name: secret2 + tolerations: + - effect: NoSchedule + key: key2 + operator: Equal + value: value2 volumes: - emptyDir: sizeLimit: 500Gi