Skip to content

Ray cluster client redesign #65

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Sep 13, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions support/ray_api.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,9 +27,9 @@ func GetRayJobAPIDetails(t Test, rayClient RayClusterClient, jobID string) *RayJ

func WriteRayJobAPILogs(t Test, rayClient RayClusterClient, jobID string) {
t.T().Helper()
logs, err := rayClient.GetJobLogs(jobID)
jobLogs, err := rayClient.GetJobLogs(jobID)
t.Expect(err).NotTo(gomega.HaveOccurred())
WriteToOutputDir(t, "ray-job-log-"+jobID, Log, []byte(logs))
WriteToOutputDir(t, "ray-job-log-"+jobID, Log, []byte(jobLogs.Logs))
}

func RayJobAPIDetails(t Test, rayClient RayClusterClient, jobID string) func(g gomega.Gomega) *RayJobDetailsResponse {
Expand Down
114 changes: 41 additions & 73 deletions support/ray_cluster_client.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@ package support

import (
"bytes"
"crypto/tls"
"encoding/json"
"fmt"
"io"
Expand Down Expand Up @@ -47,39 +46,30 @@ type RayJobLogsResponse struct {
}

type RayClusterClientConfig struct {
Address string
Client *http.Client
InsecureSkipVerify bool
Address string
Client *http.Client
}

var _ RayClusterClient = (*rayClusterClient)(nil)

type rayClusterClient struct {
endpoint url.URL
httpClient *http.Client
bearerToken string
config RayClusterClientConfig
}

type RayClusterClient interface {
CreateJob(job *RayJobSetup) (*RayJobResponse, error)
GetJobDetails(jobID string) (*RayJobDetailsResponse, error)
GetJobLogs(jobID string) (string, error)
GetJobs() (*[]RayJobDetailsResponse, error)
GetJobLogs(jobID string) (*RayJobLogsResponse, error)
ListJobs() ([]RayJobDetailsResponse, error)
}

func NewRayClusterClient(config RayClusterClientConfig, bearerToken string) (RayClusterClient, error) {
tr := &http.Transport{
TLSClientConfig: &tls.Config{InsecureSkipVerify: config.InsecureSkipVerify},
Proxy: http.ProxyFromEnvironment,
}
config.Client = &http.Client{Transport: tr}
func NewRayClusterClient(config RayClusterClientConfig) (RayClusterClient, error) {
endpoint, err := url.Parse(config.Address)
if err != nil {
return nil, fmt.Errorf("invalid dashboard endpoint address")
}
rayClusterApiClient := &rayClusterClient{
endpoint: *endpoint, httpClient: config.Client, bearerToken: bearerToken,
return nil, fmt.Errorf("invalid dashboard endpoint address: %s", endpoint)
}

rayClusterApiClient := &rayClusterClient{config}
return rayClusterApiClient, nil
}

Expand All @@ -89,13 +79,15 @@ func (client *rayClusterClient) CreateJob(job *RayJobSetup) (response *RayJobRes
return
}

createJobURL := client.endpoint.String() + "/api/jobs/"
createJobURL := client.config.Address + "/api/jobs/"

resp, err := client.httpClient.Post(createJobURL, "application/json", bytes.NewReader(marshalled))
resp, err := client.config.Client.Post(createJobURL, "application/json", bytes.NewReader(marshalled))
if err != nil {
return
}

defer resp.Body.Close()

respData, err := io.ReadAll(resp.Body)
if err != nil {
return
Expand All @@ -109,95 +101,71 @@ func (client *rayClusterClient) CreateJob(job *RayJobSetup) (response *RayJobRes
return
}

func (client *rayClusterClient) GetJobs() (response *[]RayJobDetailsResponse, err error) {
getAllJobsDetailsURL := client.endpoint.String() + "/api/jobs/"
func (client *rayClusterClient) ListJobs() (response []RayJobDetailsResponse, err error) {
getAllJobsDetailsURL := client.config.Address + "/api/jobs/"

req, err := http.NewRequest(http.MethodGet, getAllJobsDetailsURL, nil)
resp, err := client.config.Client.Get(getAllJobsDetailsURL)
if err != nil {
return nil, err
}
if client.bearerToken != "" {
req.Header.Set("Authorization", "Bearer "+client.bearerToken)
}
resp, err := client.httpClient.Do(req)
if err != nil {
return nil, err
return
}

defer resp.Body.Close()
if resp.StatusCode == 503 {
return nil, fmt.Errorf("service unavailable")
}

respData, err := io.ReadAll(resp.Body)
if err != nil {
return nil, err
return
}

if resp.StatusCode != 200 {
return nil, fmt.Errorf("incorrect response code: %d for retrieving Ray Job details, response body: %s", resp.StatusCode, respData)
}

err = json.Unmarshal(respData, &response)
if err != nil {
return nil, err
}
return response, nil
return
}

func (client *rayClusterClient) GetJobDetails(jobID string) (response *RayJobDetailsResponse, err error) {
getJobDetailsURL := client.endpoint.String() + "/api/jobs/" + jobID
getJobDetailsURL := client.config.Address + "/api/jobs/" + jobID

req, err := http.NewRequest(http.MethodGet, getJobDetailsURL, nil)
resp, err := client.config.Client.Get(getJobDetailsURL)
if err != nil {
return nil, err
}
if client.bearerToken != "" {
req.Header.Set("Authorization", "Bearer "+client.bearerToken)
return
}

resp, err := client.httpClient.Do(req)
if err != nil {
return nil, err
}
if resp.StatusCode == 503 {
return nil, fmt.Errorf("service unavailable")
}
defer resp.Body.Close()

respData, err := io.ReadAll(resp.Body)
if err != nil {
return
}

if resp.StatusCode != 200 {
return nil, fmt.Errorf("incorrect response code: %d for retrieving Ray Job details, response body: %s", resp.StatusCode, respData)
}

err = json.Unmarshal(respData, &response)
if err != nil {
return nil, err
}
return response, nil
return
}

func (client *rayClusterClient) GetJobLogs(jobID string) (logs string, err error) {
getJobLogsURL := client.endpoint.String() + "/api/jobs/" + jobID + "/logs"
req, err := http.NewRequest(http.MethodGet, getJobLogsURL, nil)
if err != nil {
return "", err
}
if client.bearerToken != "" {
req.Header.Set("Authorization", "Bearer "+client.bearerToken)
}
resp, err := client.httpClient.Do(req)
func (client *rayClusterClient) GetJobLogs(jobID string) (response *RayJobLogsResponse, err error) {
getJobLogsURL := client.config.Address + "/api/jobs/" + jobID + "/logs"

resp, err := client.config.Client.Get(getJobLogsURL)
if err != nil {
return "", err
return
}

defer resp.Body.Close()

respData, err := io.ReadAll(resp.Body)
if err != nil {
return "", err
return
}

if resp.StatusCode != 200 {
return "", fmt.Errorf("incorrect response code: %d for retrieving Ray Job logs, response body: %s", resp.StatusCode, respData)
return nil, fmt.Errorf("incorrect response code: %d for retrieving Ray Job logs, response body: %s", resp.StatusCode, respData)
}

jobLogs := RayJobLogsResponse{}
err = json.Unmarshal(respData, &jobLogs)
return jobLogs.Logs, err
err = json.Unmarshal(respData, &response)
return
}
44 changes: 44 additions & 0 deletions support/ray_cluster_client_helper.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
/*
Copyright 2024.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/

package support

import (
"crypto/tls"
"net/http"

. "github.com/onsi/gomega"

"k8s.io/client-go/transport"
)

func GetRayClusterClient(t Test, dashboardURL, bearerToken string) RayClusterClient {
t.T().Helper()

// Skip TLS check to work on clusters with insecure certificates too
// Functionality intended just for testing purpose, DO NOT USE IN PRODUCTION
tr := &http.Transport{
TLSClientConfig: &tls.Config{InsecureSkipVerify: true},
Proxy: http.ProxyFromEnvironment,
}
client, err := NewRayClusterClient(RayClusterClientConfig{
Address: dashboardURL,
Client: &http.Client{Transport: transport.NewBearerAuthRoundTripper(bearerToken, tr)},
})
t.Expect(err).NotTo(HaveOccurred())

return client
}