Skip to content

Commit 88b2518

Browse files
authored
remove: DDPJobDefinition from SDK (#498)
* remove: DDPJobDefinition and update tests * add: address comments
1 parent a06cf4f commit 88b2518

9 files changed

+67
-667
lines changed

Diff for: src/codeflare_sdk/__init__.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,6 @@
1616
list_all_clusters,
1717
)
1818

19-
from .job import JobDefinition, Job, DDPJobDefinition, DDPJob, RayJobClient
19+
from .job import RayJobClient
2020

2121
from .utils import generate_cert

Diff for: src/codeflare_sdk/cluster/cluster.py

-18
Original file line numberDiff line numberDiff line change
@@ -21,10 +21,8 @@
2121
from time import sleep
2222
from typing import List, Optional, Tuple, Dict
2323

24-
import openshift as oc
2524
from kubernetes import config
2625
from ray.job_submission import JobSubmissionClient
27-
import urllib3
2826

2927
from .auth import config_check, api_config_handler
3028
from ..utils import pretty_print
@@ -58,8 +56,6 @@ class Cluster:
5856
Note that currently, the underlying implementation is a Ray cluster.
5957
"""
6058

61-
torchx_scheduler = "ray"
62-
6359
def __init__(self, config: ClusterConfiguration):
6460
"""
6561
Create the resource cluster object by passing in a ClusterConfiguration
@@ -477,20 +473,6 @@ def job_logs(self, job_id: str) -> str:
477473
"""
478474
return self.job_client.get_job_logs(job_id)
479475

480-
def torchx_config(
481-
self, working_dir: str = None, requirements: str = None
482-
) -> Dict[str, str]:
483-
dashboard_address = urllib3.util.parse_url(self.cluster_dashboard_uri()).host
484-
to_return = {
485-
"cluster_name": self.config.name,
486-
"dashboard_address": dashboard_address,
487-
}
488-
if working_dir:
489-
to_return["working_dir"] = working_dir
490-
if requirements:
491-
to_return["requirements"] = requirements
492-
return to_return
493-
494476
def from_k8_cluster_object(
495477
rc,
496478
mcad=True,

Diff for: src/codeflare_sdk/job/__init__.py

-2
Original file line numberDiff line numberDiff line change
@@ -1,3 +1 @@
1-
from .jobs import JobDefinition, Job, DDPJobDefinition, DDPJob
2-
31
from .ray_jobs import RayJobClient

Diff for: src/codeflare_sdk/job/jobs.py

-207
This file was deleted.

Diff for: tests/e2e/mnist_raycluster_sdk_oauth_test.py

+24-18
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,8 @@
22

33
from time import sleep
44

5-
from torchx.specs.api import AppState, is_terminal
6-
75
from codeflare_sdk import Cluster, ClusterConfiguration, TokenAuthentication
8-
from codeflare_sdk.job.jobs import DDPJobDefinition
6+
from codeflare_sdk.job import RayJobClient
97

108
import pytest
119

@@ -79,7 +77,7 @@ def assert_jobsubmit_withoutLogin(self, cluster):
7977
"entrypoint": "python mnist.py",
8078
"runtime_env": {
8179
"working_dir": "./tests/e2e/",
82-
"pip": "mnist_pip_requirements.txt",
80+
"pip": "./tests/e2e/mnist_pip_requirements.txt",
8381
},
8482
}
8583
try:
@@ -98,19 +96,26 @@ def assert_jobsubmit_withoutLogin(self, cluster):
9896

9997
def assert_jobsubmit_withlogin(self, cluster):
10098
self.assert_appwrapper_exists()
101-
jobdef = DDPJobDefinition(
102-
name="mnist",
103-
script="./tests/e2e/mnist.py",
104-
scheduler_args={"requirements": "./tests/e2e/mnist_pip_requirements.txt"},
99+
auth_token = run_oc_command(["whoami", "--show-token=true"])
100+
ray_dashboard = cluster.cluster_dashboard_uri()
101+
header = {"Authorization": f"Bearer {auth_token}"}
102+
client = RayJobClient(address=ray_dashboard, headers=header, verify=True)
103+
104+
# Submit the job
105+
submission_id = client.submit_job(
106+
entrypoint="python mnist.py",
107+
runtime_env={
108+
"working_dir": "./tests/e2e/",
109+
"pip": "mnist_pip_requirements.txt",
110+
},
105111
)
106-
job = jobdef.submit(cluster)
107-
112+
print(f"Submitted job with ID: {submission_id}")
108113
done = False
109114
time = 0
110115
timeout = 900
111116
while not done:
112-
status = job.status()
113-
if is_terminal(status.state):
117+
status = client.get_job_status(submission_id)
118+
if status.is_terminal():
114119
break
115120
if not done:
116121
print(status)
@@ -119,11 +124,12 @@ def assert_jobsubmit_withlogin(self, cluster):
119124
sleep(5)
120125
time += 5
121126

122-
print(job.status())
123-
self.assert_job_completion(status)
127+
logs = client.get_job_logs(submission_id)
128+
print(logs)
124129

125-
print(job.logs())
130+
self.assert_job_completion(status)
126131

132+
client.delete_job(submission_id)
127133
cluster.down()
128134

129135
def assert_appwrapper_exists(self):
@@ -144,9 +150,9 @@ def assert_appwrapper_exists(self):
144150
assert False
145151

146152
def assert_job_completion(self, status):
147-
if status.state == AppState.SUCCEEDED:
148-
print(f"Job has completed: '{status.state}'")
153+
if status == "SUCCEEDED":
154+
print(f"Job has completed: '{status}'")
149155
assert True
150156
else:
151-
print(f"Job has completed: '{status.state}'")
157+
print(f"Job has completed: '{status}'")
152158
assert False

0 commit comments

Comments
 (0)