Skip to content

Commit 1086b7a

Browse files
committed
Add support for torch-amd and torch-rocm image
1 parent 11fd6e3 commit 1086b7a

File tree

2 files changed

+19
-7
lines changed

2 files changed

+19
-7
lines changed

Diff for: support/defaults.go

+5-3
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,9 @@ package support
55
// ***********************
66

77
const (
8-
RayVersion = "2.35.0"
9-
RayImage = "quay.io/modh/ray:2.35.0-py39-cu121"
10-
RayROCmImage = "quay.io/modh/ray:2.35.0-py39-rocm61"
8+
RayVersion = "2.35.0"
9+
RayImage = "quay.io/modh/ray:2.35.0-py39-cu121"
10+
RayROCmImage = "quay.io/modh/ray:2.35.0-py39-rocm61"
11+
RayTorchCudaImage = "quay.io/rhoai/2.35.0-py39-cu121-torch24-fa26"
12+
RayTorchROCmImage = "quay.io/rhoai/ray:2.35.0-py39-rocm61-torch24-fa26"
1113
)

Diff for: support/environment.go

+14-4
Original file line numberDiff line numberDiff line change
@@ -25,10 +25,12 @@ const (
2525
// The environment variables hereafter can be used to change the components
2626
// used for testing.
2727

28-
CodeFlareTestRayVersion = "CODEFLARE_TEST_RAY_VERSION"
29-
CodeFlareTestRayImage = "CODEFLARE_TEST_RAY_IMAGE"
30-
CodeFlareTestRayROCmImage = "CODEFLARE_TEST_RAY_ROCM_IMAGE"
31-
CodeFlareTestPyTorchImage = "CODEFLARE_TEST_PYTORCH_IMAGE"
28+
CodeFlareTestRayVersion = "CODEFLARE_TEST_RAY_VERSION"
29+
CodeFlareTestRayImage = "CODEFLARE_TEST_RAY_IMAGE"
30+
CodeFlareTestRayROCmImage = "CODEFLARE_TEST_RAY_ROCM_IMAGE"
31+
CodeFlareTestRayTorchCudaImage = "CODEFLARE_TEST_RAY_TORCH_CUDA_IMAGE"
32+
CodeFlareTestRayTorchROCmImage = "CODEFLARE_TEST_RAY_TORCH_ROCM_IMAGE"
33+
CodeFlareTestPyTorchImage = "CODEFLARE_TEST_PYTORCH_IMAGE"
3234

3335
// The testing output directory, to write output files into.
3436
CodeFlareTestOutputDir = "CODEFLARE_TEST_OUTPUT_DIR"
@@ -83,6 +85,14 @@ func GetRayROCmImage() string {
8385
return lookupEnvOrDefault(CodeFlareTestRayROCmImage, RayROCmImage)
8486
}
8587

88+
func GetRayTorchCudaImage() string {
89+
return lookupEnvOrDefault(CodeFlareTestRayTorchCudaImage, RayTorchCudaImage)
90+
}
91+
92+
func GetRayTorchROCmImage() string {
93+
return lookupEnvOrDefault(CodeFlareTestRayTorchROCmImage, RayTorchROCmImage)
94+
}
95+
8696
func GetPyTorchImage() string {
8797
return lookupEnvOrDefault(CodeFlareTestPyTorchImage, "pytorch/pytorch:1.11.0-cuda11.3-cudnn8-runtime")
8898
}

0 commit comments

Comments
 (0)