Skip to content

Commit 658779c

Browse files
committed
Adding unit tests for apply
1 parent 554d7f2 commit 658779c

File tree

5 files changed

+142
-25
lines changed

5 files changed

+142
-25
lines changed

Diff for: src/codeflare_sdk/common/utils/unit_test_support.py

+58-3
Original file line numberDiff line numberDiff line change
@@ -27,10 +27,15 @@
2727

2828

2929
def createClusterConfig():
30+
config = createClusterConfigWithNumWorkers()
31+
return config
32+
33+
34+
def createClusterConfigWithNumWorkers(num_workers=2):
3035
config = ClusterConfiguration(
3136
name="unit-test-cluster",
3237
namespace="ns",
33-
num_workers=2,
38+
num_workers=num_workers,
3439
worker_cpu_requests=3,
3540
worker_cpu_limits=4,
3641
worker_memory_requests=5,
@@ -41,13 +46,21 @@ def createClusterConfig():
4146
return config
4247

4348

44-
def createClusterWithConfig(mocker):
49+
def createClusterWithConfigAndNumWorkers(mocker, num_workers=2, dynamic_client=None):
4550
mocker.patch("kubernetes.config.load_kube_config", return_value="ignore")
4651
mocker.patch(
4752
"kubernetes.client.CustomObjectsApi.get_cluster_custom_object",
4853
return_value={"spec": {"domain": "apps.cluster.awsroute.org"}},
4954
)
50-
cluster = Cluster(createClusterConfig())
55+
cluster = Cluster(createClusterConfigWithNumWorkers(num_workers))
56+
mocker.patch.object(cluster, "get_dynamic_client", return_value=dynamic_client)
57+
mocker.patch.object(cluster, "down", return_value=None)
58+
mocker.patch.object(cluster, "config_check", return_value=None)
59+
return cluster
60+
61+
62+
def createClusterWithConfig(mock_config):
63+
cluster = createClusterWithConfigAndNumWorkers(mock_config)
5164
return cluster
5265

5366

@@ -383,6 +396,48 @@ def mocked_ingress(port, cluster_name="unit-test-cluster", annotations: dict = N
383396
return mock_ingress
384397

385398

399+
# Global dictionary to maintain state in the mock
400+
cluster_state = {}
401+
402+
403+
# The mock side_effect function for server_side_apply
404+
def mock_server_side_apply(resource, body=None, name=None, namespace=None, **kwargs):
405+
# Simulate the behavior of server_side_apply:
406+
# Update a mock state that represents the cluster's current configuration.
407+
# Stores the state in a global dictionary for simplicity.
408+
409+
global cluster_state
410+
411+
if not resource or not body or not name or not namespace:
412+
raise ValueError("Missing required parameters for server_side_apply")
413+
414+
# Extract worker count from the body if it exists
415+
try:
416+
worker_count = (
417+
body["spec"]["workerGroupSpecs"][0]["replicas"]
418+
if "spec" in body and "workerGroupSpecs" in body["spec"]
419+
else None
420+
)
421+
except KeyError:
422+
worker_count = None
423+
424+
# Apply changes to the cluster_state mock
425+
cluster_state[name] = {
426+
"namespace": namespace,
427+
"worker_count": worker_count,
428+
"body": body,
429+
}
430+
431+
# Return a response that mimics the behavior of a successful apply
432+
return {
433+
"status": "success",
434+
"applied": True,
435+
"name": name,
436+
"namespace": namespace,
437+
"worker_count": worker_count,
438+
}
439+
440+
386441
@patch.dict("os.environ", {"NB_PREFIX": "test-prefix"})
387442
def create_cluster_all_config_params(mocker, cluster_name, is_appwrapper) -> Cluster:
388443
mocker.patch(

Diff for: src/codeflare_sdk/ray/cluster/cluster.py

+21-15
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,14 @@ def __init__(self, config: ClusterConfiguration):
8888
if is_notebook():
8989
cluster_up_down_buttons(self)
9090

91+
def get_dynamic_client(self):
92+
"""Return a dynamic client, optionally mocked in tests."""
93+
return DynamicClient(get_api_client())
94+
95+
def config_check(self):
96+
"""Return a dynamic client, optionally mocked in tests."""
97+
return config_check()
98+
9199
@property
92100
def _client_headers(self):
93101
k8_client = get_api_client()
@@ -144,7 +152,7 @@ def up(self):
144152
the Kueue localqueue.
145153
"""
146154
# TODO: Add deprecation message in favor of apply()
147-
# print( "WARNING: The up() is planned for deprecation in favor of apply().")
155+
# print( "WARNING: The up() is planned for deprecation in favor of apply().")
148156

149157
# check if RayCluster CustomResourceDefinition exists if not throw RuntimeError
150158
self._throw_for_no_raycluster()
@@ -182,30 +190,27 @@ def up(self):
182190
except Exception as e: # pragma: no cover
183191
return _kube_api_error_handling(e)
184192

185-
186193
def apply(self, force=False):
187194
"""
188195
Applies the Cluster yaml using server-side apply.
189196
If 'force' is set to True, conflicts will be forced.
190197
"""
198+
# check if RayCluster CustomResourceDefinition exists if not throw RuntimeError
199+
self._throw_for_no_raycluster()
200+
namespace = self.config.namespace
191201
# Ensure Kubernetes configuration is loaded
192-
config_check()
193-
194-
# Create a dynamic client for interacting with custom resources
195-
dynamic_client = DynamicClient(get_api_client())
196-
197202
try:
203+
self.config_check()
198204
# Get the RayCluster custom resource definition
199-
api = dynamic_client.resources.get(
200-
api_version="ray.io/v1",
201-
kind="RayCluster"
205+
api = self.get_dynamic_client().resources.get(
206+
api_version="ray.io/v1", kind="RayCluster"
202207
)
203208
except Exception as e:
204209
raise RuntimeError("Failed to get RayCluster resource: " + str(e))
205210

206211
# Read the YAML file and parse it into a dictionary
207212
try:
208-
with open(self.resource_yaml, 'r') as f:
213+
with open(self.resource_yaml, "r") as f:
209214
resource_body = yaml.safe_load(f)
210215
except FileNotFoundError:
211216
raise RuntimeError(f"Resource YAML file '{self.resource_yaml}' not found.")
@@ -216,15 +221,14 @@ def apply(self, force=False):
216221
resource_name = resource_body.get("metadata", {}).get("name")
217222
if not resource_name:
218223
raise ValueError("The resource must have a 'metadata.name' field.")
219-
220224
try:
221225
# Use server-side apply
222226
resp = api.server_side_apply(
223227
body=resource_body,
224228
name=resource_name,
225229
namespace=self.config.namespace,
226230
field_manager="cluster-manager",
227-
force_conflicts=force # Allow forcing conflicts if needed
231+
force_conflicts=force, # Allow forcing conflicts if needed
228232
)
229233
print(f"Cluster '{self.config.name}' applied successfully.")
230234
except ApiException as e:
@@ -234,7 +238,9 @@ def apply(self, force=False):
234238
"To force the patch, set 'force=True' in the apply() method."
235239
)
236240
elif e.status == 404:
237-
print(f"Namespace '{self.config.namespace}' or resource '{resource_name}' not found. Verify the namespace or CRD.")
241+
print(
242+
f"Namespace '{self.config.namespace}' or resource '{resource_name}' not found. Verify the namespace or CRD."
243+
)
238244
else:
239245
raise RuntimeError(f"Failed to apply cluster: {e.reason}")
240246

@@ -266,7 +272,7 @@ def down(self):
266272
resource_name = self.config.name
267273
self._throw_for_no_raycluster()
268274
try:
269-
config_check()
275+
self.config_check()
270276
api_instance = client.CustomObjectsApi(get_api_client())
271277
if self.config.appwrapper:
272278
api_instance.delete_namespaced_custom_object(

Diff for: src/codeflare_sdk/ray/cluster/test_cluster.py

+42-3
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
)
2121
from codeflare_sdk.common.utils.unit_test_support import (
2222
createClusterWithConfig,
23+
createClusterWithConfigAndNumWorkers,
2324
arg_check_del_effect,
2425
ingress_retrieval,
2526
arg_check_apply_effect,
@@ -29,6 +30,7 @@
2930
get_obj_none,
3031
get_ray_obj_with_status,
3132
get_aw_obj_with_status,
33+
mock_server_side_apply,
3234
)
3335
from codeflare_sdk.ray.cluster.cluster import _is_openshift_cluster
3436
from pathlib import Path
@@ -67,11 +69,48 @@ def test_cluster_up_down(mocker):
6769
"kubernetes.client.CustomObjectsApi.list_namespaced_custom_object",
6870
return_value=get_local_queue("kueue.x-k8s.io", "v1beta1", "ns", "localqueues"),
6971
)
70-
cluster = cluster = createClusterWithConfig(mocker)
72+
cluster = createClusterWithConfig(mocker)
7173
cluster.up()
7274
cluster.down()
7375

7476

77+
def test_cluster_apply_scale_up_scale_down(mocker):
78+
mock_dynamic_client = mocker.Mock()
79+
mocker.patch("codeflare_sdk.ray.cluster.cluster.Cluster._throw_for_no_raycluster")
80+
mocker.patch(
81+
"kubernetes.dynamic.DynamicClient.resources", new_callable=mocker.PropertyMock
82+
)
83+
mocker.patch(
84+
"codeflare_sdk.ray.cluster.cluster.Cluster.create_resource",
85+
return_value="./tests/test_cluster_yamls/ray/default-ray-cluster.yaml",
86+
)
87+
88+
# Initialize test
89+
initial_num_workers = 1
90+
scaled_up_num_workers = 2
91+
92+
# Step 1: Create cluster with initial workers
93+
cluster = createClusterWithConfigAndNumWorkers(
94+
mocker, initial_num_workers, dynamic_client=mock_dynamic_client
95+
)
96+
cluster.apply()
97+
98+
# Step 2: Scale up the cluster
99+
cluster = createClusterWithConfigAndNumWorkers(
100+
mocker, scaled_up_num_workers, dynamic_client=mock_dynamic_client
101+
)
102+
cluster.apply()
103+
104+
# Step 3: Scale down the cluster
105+
cluster = createClusterWithConfigAndNumWorkers(
106+
mocker, initial_num_workers, dynamic_client=mock_dynamic_client
107+
)
108+
cluster.apply()
109+
110+
# Tear down
111+
cluster.down()
112+
113+
75114
def test_cluster_up_down_no_mcad(mocker):
76115
mocker.patch("codeflare_sdk.ray.cluster.cluster.Cluster._throw_for_no_raycluster")
77116
mocker.patch("kubernetes.config.load_kube_config", return_value="ignore")
@@ -117,7 +156,7 @@ def test_cluster_uris(mocker):
117156
"kubernetes.client.CustomObjectsApi.list_namespaced_custom_object",
118157
return_value=get_local_queue("kueue.x-k8s.io", "v1beta1", "ns", "localqueues"),
119158
)
120-
cluster = cluster = createClusterWithConfig(mocker)
159+
cluster = createClusterWithConfig(mocker)
121160
mocker.patch(
122161
"kubernetes.client.NetworkingV1Api.list_namespaced_ingress",
123162
return_value=ingress_retrieval(
@@ -159,7 +198,7 @@ def ray_addr(self, *args):
159198
"kubernetes.client.CustomObjectsApi.list_namespaced_custom_object",
160199
return_value=get_local_queue("kueue.x-k8s.io", "v1beta1", "ns", "localqueues"),
161200
)
162-
cluster = cluster = createClusterWithConfig(mocker)
201+
cluster = createClusterWithConfig(mocker)
163202
mocker.patch(
164203
"ray.job_submission.JobSubmissionClient._check_connection_and_version_with_url",
165204
return_value="None",

Diff for: tests/e2e/cluster_apply_test.py renamed to tests/e2e/cluster_apply_kind_test.py

+5-4
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,7 @@ def test_cluster_apply(self):
6262
updated_config = ClusterConfiguration(
6363
name=cluster_name,
6464
namespace=namespace,
65-
num_workers=3,
65+
num_workers=2,
6666
head_cpu_requests="500m",
6767
head_cpu_limits="1",
6868
head_memory_requests="1Gi",
@@ -82,12 +82,14 @@ def test_cluster_apply(self):
8282
# Wait for the updated cluster to be ready
8383
cluster.wait_ready()
8484
updated_status = cluster.status()
85-
assert updated_status["ready"], f"Cluster {cluster_name} is not ready after update: {updated_status}"
85+
assert updated_status[
86+
"ready"
87+
], f"Cluster {cluster_name} is not ready after update: {updated_status}"
8688

8789
# Verify the cluster is updated
8890
updated_ray_cluster = get_ray_cluster(cluster_name, namespace)
8991
assert (
90-
updated_ray_cluster["spec"]["workerGroupSpecs"][0]["replicas"] == 3
92+
updated_ray_cluster["spec"]["workerGroupSpecs"][0]["replicas"] == 2
9193
), "Worker count was not updated"
9294

9395
# Clean up
@@ -152,4 +154,3 @@ def test_apply_invalid_update(self):
152154

153155
# Clean up
154156
cluster.down()
155-

Diff for: tests/e2e/support.py

+16
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,22 @@
1111
)
1212

1313

14+
def get_ray_cluster(cluster_name, namespace):
15+
api = client.CustomObjectsApi()
16+
try:
17+
return api.get_namespaced_custom_object(
18+
group="ray.io",
19+
version="v1",
20+
namespace=namespace,
21+
plural="rayclusters",
22+
name=cluster_name,
23+
)
24+
except client.exceptions.ApiException as e:
25+
if e.status == 404:
26+
return None
27+
raise
28+
29+
1430
def get_ray_image():
1531
default_ray_image = "quay.io/modh/ray@sha256:0d715f92570a2997381b7cafc0e224cfa25323f18b9545acfd23bc2b71576d06"
1632
return os.getenv("RAY_IMAGE", default_ray_image)

0 commit comments

Comments
 (0)