Skip to content

Commit 2f60dc7

Browse files
sutaakaropenshift-merge-bot[bot]
authored andcommitted
Ray cluster client redesign
1 parent 848538d commit 2f60dc7

File tree

3 files changed

+87
-75
lines changed

3 files changed

+87
-75
lines changed

support/ray_api.go

+2-2
Original file line numberDiff line numberDiff line change
@@ -27,9 +27,9 @@ func GetRayJobAPIDetails(t Test, rayClient RayClusterClient, jobID string) *RayJ
2727

2828
func WriteRayJobAPILogs(t Test, rayClient RayClusterClient, jobID string) {
2929
t.T().Helper()
30-
logs, err := rayClient.GetJobLogs(jobID)
30+
jobLogs, err := rayClient.GetJobLogs(jobID)
3131
t.Expect(err).NotTo(gomega.HaveOccurred())
32-
WriteToOutputDir(t, "ray-job-log-"+jobID, Log, []byte(logs))
32+
WriteToOutputDir(t, "ray-job-log-"+jobID, Log, []byte(jobLogs.Logs))
3333
}
3434

3535
func RayJobAPIDetails(t Test, rayClient RayClusterClient, jobID string) func(g gomega.Gomega) *RayJobDetailsResponse {

support/ray_cluster_client.go

+41-73
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,6 @@ package support
1818

1919
import (
2020
"bytes"
21-
"crypto/tls"
2221
"encoding/json"
2322
"fmt"
2423
"io"
@@ -47,39 +46,30 @@ type RayJobLogsResponse struct {
4746
}
4847

4948
type RayClusterClientConfig struct {
50-
Address string
51-
Client *http.Client
52-
InsecureSkipVerify bool
49+
Address string
50+
Client *http.Client
5351
}
5452

5553
var _ RayClusterClient = (*rayClusterClient)(nil)
5654

5755
type rayClusterClient struct {
58-
endpoint url.URL
59-
httpClient *http.Client
60-
bearerToken string
56+
config RayClusterClientConfig
6157
}
6258

6359
type RayClusterClient interface {
6460
CreateJob(job *RayJobSetup) (*RayJobResponse, error)
6561
GetJobDetails(jobID string) (*RayJobDetailsResponse, error)
66-
GetJobLogs(jobID string) (string, error)
67-
GetJobs() (*[]RayJobDetailsResponse, error)
62+
GetJobLogs(jobID string) (*RayJobLogsResponse, error)
63+
ListJobs() ([]RayJobDetailsResponse, error)
6864
}
6965

70-
func NewRayClusterClient(config RayClusterClientConfig, bearerToken string) (RayClusterClient, error) {
71-
tr := &http.Transport{
72-
TLSClientConfig: &tls.Config{InsecureSkipVerify: config.InsecureSkipVerify},
73-
Proxy: http.ProxyFromEnvironment,
74-
}
75-
config.Client = &http.Client{Transport: tr}
66+
func NewRayClusterClient(config RayClusterClientConfig) (RayClusterClient, error) {
7667
endpoint, err := url.Parse(config.Address)
7768
if err != nil {
78-
return nil, fmt.Errorf("invalid dashboard endpoint address")
79-
}
80-
rayClusterApiClient := &rayClusterClient{
81-
endpoint: *endpoint, httpClient: config.Client, bearerToken: bearerToken,
69+
return nil, fmt.Errorf("invalid dashboard endpoint address: %s", endpoint)
8270
}
71+
72+
rayClusterApiClient := &rayClusterClient{config}
8373
return rayClusterApiClient, nil
8474
}
8575

@@ -89,13 +79,15 @@ func (client *rayClusterClient) CreateJob(job *RayJobSetup) (response *RayJobRes
8979
return
9080
}
9181

92-
createJobURL := client.endpoint.String() + "/api/jobs/"
82+
createJobURL := client.config.Address + "/api/jobs/"
9383

94-
resp, err := client.httpClient.Post(createJobURL, "application/json", bytes.NewReader(marshalled))
84+
resp, err := client.config.Client.Post(createJobURL, "application/json", bytes.NewReader(marshalled))
9585
if err != nil {
9686
return
9787
}
9888

89+
defer resp.Body.Close()
90+
9991
respData, err := io.ReadAll(resp.Body)
10092
if err != nil {
10193
return
@@ -109,95 +101,71 @@ func (client *rayClusterClient) CreateJob(job *RayJobSetup) (response *RayJobRes
109101
return
110102
}
111103

112-
func (client *rayClusterClient) GetJobs() (response *[]RayJobDetailsResponse, err error) {
113-
getAllJobsDetailsURL := client.endpoint.String() + "/api/jobs/"
104+
func (client *rayClusterClient) ListJobs() (response []RayJobDetailsResponse, err error) {
105+
getAllJobsDetailsURL := client.config.Address + "/api/jobs/"
114106

115-
req, err := http.NewRequest(http.MethodGet, getAllJobsDetailsURL, nil)
107+
resp, err := client.config.Client.Get(getAllJobsDetailsURL)
116108
if err != nil {
117-
return nil, err
118-
}
119-
if client.bearerToken != "" {
120-
req.Header.Set("Authorization", "Bearer "+client.bearerToken)
121-
}
122-
resp, err := client.httpClient.Do(req)
123-
if err != nil {
124-
return nil, err
109+
return
125110
}
111+
126112
defer resp.Body.Close()
127-
if resp.StatusCode == 503 {
128-
return nil, fmt.Errorf("service unavailable")
129-
}
113+
130114
respData, err := io.ReadAll(resp.Body)
131115
if err != nil {
132-
return nil, err
116+
return
133117
}
118+
134119
if resp.StatusCode != 200 {
135120
return nil, fmt.Errorf("incorrect response code: %d for retrieving Ray Job details, response body: %s", resp.StatusCode, respData)
136121
}
122+
137123
err = json.Unmarshal(respData, &response)
138-
if err != nil {
139-
return nil, err
140-
}
141-
return response, nil
124+
return
142125
}
143126

144127
func (client *rayClusterClient) GetJobDetails(jobID string) (response *RayJobDetailsResponse, err error) {
145-
getJobDetailsURL := client.endpoint.String() + "/api/jobs/" + jobID
128+
getJobDetailsURL := client.config.Address + "/api/jobs/" + jobID
146129

147-
req, err := http.NewRequest(http.MethodGet, getJobDetailsURL, nil)
130+
resp, err := client.config.Client.Get(getJobDetailsURL)
148131
if err != nil {
149-
return nil, err
150-
}
151-
if client.bearerToken != "" {
152-
req.Header.Set("Authorization", "Bearer "+client.bearerToken)
132+
return
153133
}
154134

155-
resp, err := client.httpClient.Do(req)
156-
if err != nil {
157-
return nil, err
158-
}
159-
if resp.StatusCode == 503 {
160-
return nil, fmt.Errorf("service unavailable")
161-
}
135+
defer resp.Body.Close()
162136

163137
respData, err := io.ReadAll(resp.Body)
164138
if err != nil {
165139
return
166140
}
141+
167142
if resp.StatusCode != 200 {
168143
return nil, fmt.Errorf("incorrect response code: %d for retrieving Ray Job details, response body: %s", resp.StatusCode, respData)
169144
}
145+
170146
err = json.Unmarshal(respData, &response)
171-
if err != nil {
172-
return nil, err
173-
}
174-
return response, nil
147+
return
175148
}
176149

177-
func (client *rayClusterClient) GetJobLogs(jobID string) (logs string, err error) {
178-
getJobLogsURL := client.endpoint.String() + "/api/jobs/" + jobID + "/logs"
179-
req, err := http.NewRequest(http.MethodGet, getJobLogsURL, nil)
180-
if err != nil {
181-
return "", err
182-
}
183-
if client.bearerToken != "" {
184-
req.Header.Set("Authorization", "Bearer "+client.bearerToken)
185-
}
186-
resp, err := client.httpClient.Do(req)
150+
func (client *rayClusterClient) GetJobLogs(jobID string) (response *RayJobLogsResponse, err error) {
151+
getJobLogsURL := client.config.Address + "/api/jobs/" + jobID + "/logs"
152+
153+
resp, err := client.config.Client.Get(getJobLogsURL)
187154
if err != nil {
188-
return "", err
155+
return
189156
}
190157

158+
defer resp.Body.Close()
159+
191160
respData, err := io.ReadAll(resp.Body)
192161
if err != nil {
193-
return "", err
162+
return
194163
}
195164

196165
if resp.StatusCode != 200 {
197-
return "", fmt.Errorf("incorrect response code: %d for retrieving Ray Job logs, response body: %s", resp.StatusCode, respData)
166+
return nil, fmt.Errorf("incorrect response code: %d for retrieving Ray Job logs, response body: %s", resp.StatusCode, respData)
198167
}
199168

200-
jobLogs := RayJobLogsResponse{}
201-
err = json.Unmarshal(respData, &jobLogs)
202-
return jobLogs.Logs, err
169+
err = json.Unmarshal(respData, &response)
170+
return
203171
}

support/ray_cluster_client_helper.go

+44
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
/*
2+
Copyright 2024.
3+
4+
Licensed under the Apache License, Version 2.0 (the "License");
5+
you may not use this file except in compliance with the License.
6+
You may obtain a copy of the License at
7+
8+
http://www.apache.org/licenses/LICENSE-2.0
9+
10+
Unless required by applicable law or agreed to in writing, software
11+
distributed under the License is distributed on an "AS IS" BASIS,
12+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
See the License for the specific language governing permissions and
14+
limitations under the License.
15+
*/
16+
17+
package support
18+
19+
import (
20+
"crypto/tls"
21+
"net/http"
22+
23+
. "github.com/onsi/gomega"
24+
25+
"k8s.io/client-go/transport"
26+
)
27+
28+
func GetRayClusterClient(t Test, dashboardURL, bearerToken string) RayClusterClient {
29+
t.T().Helper()
30+
31+
// Skip TLS check to work on clusters with insecure certificates too
32+
// Functionality intended just for testing purpose, DO NOT USE IN PRODUCTION
33+
tr := &http.Transport{
34+
TLSClientConfig: &tls.Config{InsecureSkipVerify: true},
35+
Proxy: http.ProxyFromEnvironment,
36+
}
37+
client, err := NewRayClusterClient(RayClusterClientConfig{
38+
Address: dashboardURL,
39+
Client: &http.Client{Transport: transport.NewBearerAuthRoundTripper(bearerToken, tr)},
40+
})
41+
t.Expect(err).NotTo(HaveOccurred())
42+
43+
return client
44+
}

0 commit comments

Comments
 (0)