Skip to content

Commit d0c1ea5

Browse files
Update support functions to handle Ray job API operations using TLS verification, and add functions for asserting Ray job status
1 parent d44e319 commit d0c1ea5

File tree

1 file changed

+101
-9
lines changed

1 file changed

+101
-9
lines changed

support/ray_cluster_client.go

Lines changed: 101 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -18,11 +18,13 @@ package support
1818

1919
import (
2020
"bytes"
21+
"crypto/tls"
2122
"encoding/json"
2223
"fmt"
2324
"io"
2425
"net/http"
2526
"net/url"
27+
"time"
2628
)
2729

2830
type RayJobSetup struct {
@@ -45,20 +47,32 @@ type RayJobLogsResponse struct {
4547
Logs string `json:"logs"`
4648
}
4749

50+
// RayClusterClientConfig holds the options applied when a Ray cluster client
// is constructed.
type RayClusterClientConfig struct {
51+
// SkipTlsVerification is copied into tls.Config.InsecureSkipVerify by
// NewRayClusterClient; when true the server certificate is not verified.
SkipTlsVerification bool
52+
}
53+
4854
var _ RayClusterClient = (*rayClusterClient)(nil)
4955

5056
// rayClusterClient is the HTTP-backed implementation of RayClusterClient.
type rayClusterClient struct {
51-
endpoint url.URL
57+
// endpoint is the base URL that the "/api/jobs/..." request paths are appended to.
endpoint url.URL
58+
// httpClient performs the requests issued by the client's methods.
httpClient *http.Client
59+
// authHeader, when non-empty, is sent as "Authorization: Bearer <authHeader>"
// on the GET requests built by this client.
authHeader string
5260
}
5361

5462
// RayClusterClient is the set of job operations offered against a Ray
// dashboard endpoint.
type RayClusterClient interface {
5563
// CreateJob sends the job setup in an application/json POST to
// <endpoint>/api/jobs/.
CreateJob(job *RayJobSetup) (*RayJobResponse, error)
5664
// GetJobDetails issues a GET against <endpoint>/api/jobs/<jobID>.
GetJobDetails(jobID string) (*RayJobDetailsResponse, error)
5765
// GetJobLogs issues a GET against <endpoint>/api/jobs/<jobID>/logs and
// returns the "logs" field of the decoded response.
GetJobLogs(jobID string) (string, error)
66+
// GetAllJobsData issues a GET against <endpoint>/api/jobs/ and decodes the
// response body as a JSON array of objects.
GetAllJobsData() ([]map[string]interface{}, error)
67+
// WaitForJobStatus polls GetJobDetails until the job reports a finished
// status (such as SUCCEEDED or FAILED) and returns that status.
WaitForJobStatus(jobID string) (string, error)
5868
}
5969

60-
func NewRayClusterClient(dashboardEndpoint url.URL) RayClusterClient {
61-
return &rayClusterClient{endpoint: dashboardEndpoint}
70+
func NewRayClusterClient(dashboardEndpoint url.URL, config RayClusterClientConfig, authHeader string) RayClusterClient {
71+
tr := &http.Transport{
72+
TLSClientConfig: &tls.Config{InsecureSkipVerify: config.SkipTlsVerification},
73+
Proxy: http.ProxyFromEnvironment,
74+
}
75+
return &rayClusterClient{endpoint: dashboardEndpoint, httpClient: &http.Client{Transport: tr}, authHeader: authHeader}
6276
}
6377

6478
func (client *rayClusterClient) CreateJob(job *RayJobSetup) (response *RayJobResponse, err error) {
@@ -68,7 +82,8 @@ func (client *rayClusterClient) CreateJob(job *RayJobSetup) (response *RayJobRes
6882
}
6983

7084
createJobURL := client.endpoint.String() + "/api/jobs/"
71-
resp, err := http.Post(createJobURL, "application/json", bytes.NewReader(marshalled))
85+
86+
resp, err := client.httpClient.Post(createJobURL, "application/json", bytes.NewReader(marshalled))
7287
if err != nil {
7388
return
7489
}
@@ -86,11 +101,51 @@ func (client *rayClusterClient) CreateJob(job *RayJobSetup) (response *RayJobRes
86101
return
87102
}
88103

104+
func (client *rayClusterClient) GetAllJobsData() ([]map[string]interface{}, error) {
105+
getAllJobsDetailsURL := client.endpoint.String() + "/api/jobs/"
106+
107+
req, err := http.NewRequest(http.MethodGet, getAllJobsDetailsURL, nil)
108+
if err != nil {
109+
return nil, err
110+
}
111+
if client.authHeader != "" {
112+
req.Header.Set("Authorization", "Bearer "+client.authHeader)
113+
}
114+
resp, err := client.httpClient.Do(req)
115+
if err != nil {
116+
return nil, err
117+
}
118+
defer resp.Body.Close()
119+
if resp.StatusCode == 503 {
120+
return nil, fmt.Errorf("service unavailable")
121+
}
122+
body, err := io.ReadAll(resp.Body)
123+
if err != nil {
124+
return nil, err
125+
}
126+
127+
var result []map[string]interface{}
128+
err = json.Unmarshal(body, &result)
129+
if err != nil {
130+
return nil, err
131+
}
132+
return result, nil
133+
}
134+
89135
func (client *rayClusterClient) GetJobDetails(jobID string) (response *RayJobDetailsResponse, err error) {
90136
getJobDetailsURL := client.endpoint.String() + "/api/jobs/" + jobID
91-
resp, err := http.Get(getJobDetailsURL)
137+
138+
req, err := http.NewRequest(http.MethodGet, getJobDetailsURL, nil)
92139
if err != nil {
93-
return
140+
return nil, err
141+
}
142+
if client.authHeader != "" {
143+
req.Header.Set("Authorization", "Bearer "+client.authHeader)
144+
}
145+
146+
resp, err := client.httpClient.Do(req)
147+
if err != nil {
148+
return nil, err
94149
}
95150

96151
respData, err := io.ReadAll(resp.Body)
@@ -108,14 +163,21 @@ func (client *rayClusterClient) GetJobDetails(jobID string) (response *RayJobDet
108163

109164
func (client *rayClusterClient) GetJobLogs(jobID string) (logs string, err error) {
110165
getJobLogsURL := client.endpoint.String() + "/api/jobs/" + jobID + "/logs"
111-
resp, err := http.Get(getJobLogsURL)
166+
req, err := http.NewRequest(http.MethodGet, getJobLogsURL, nil)
112167
if err != nil {
113-
return
168+
return "", err
169+
}
170+
if client.authHeader != "" {
171+
req.Header.Set("Authorization", "Bearer "+client.authHeader)
172+
}
173+
resp, err := client.httpClient.Do(req)
174+
if err != nil {
175+
return "", err
114176
}
115177

116178
respData, err := io.ReadAll(resp.Body)
117179
if err != nil {
118-
return
180+
return "", err
119181
}
120182

121183
if resp.StatusCode != 200 {
@@ -126,3 +188,33 @@ func (client *rayClusterClient) GetJobLogs(jobID string) (logs string, err error
126188
err = json.Unmarshal(respData, &jobLogs)
127189
return jobLogs.Logs, err
128190
}
191+
192+
func (client *rayClusterClient) WaitForJobStatus(jobID string) (string, error) {
193+
var status string
194+
var prevStatus string
195+
fmt.Printf("Waiting for job to be Succeeded...\n")
196+
var err error
197+
var resp *RayJobDetailsResponse
198+
for status != "SUCCEEDED" {
199+
resp, err = client.GetJobDetails(jobID)
200+
if err != nil {
201+
time.Sleep(2 * time.Second)
202+
continue
203+
}
204+
statusVal := resp.Status
205+
if statusVal == "SUCCEEDED" || statusVal == "FAILED" {
206+
fmt.Printf("JobStatus : %s\n", statusVal)
207+
prevStatus = statusVal
208+
return prevStatus, err
209+
}
210+
if prevStatus != statusVal && statusVal != "SUCCEEDED" {
211+
fmt.Printf("JobStatus : %s...\n", statusVal)
212+
prevStatus = statusVal
213+
}
214+
time.Sleep(3 * time.Second)
215+
}
216+
if prevStatus != "SUCCEEDED" {
217+
err = fmt.Errorf("Job failed !")
218+
}
219+
return prevStatus, err
220+
}

0 commit comments

Comments
 (0)