Skip to content

Commit e99e941

Browse files
Add support for environment variables for Training images
1 parent 2a7c1fc commit e99e941

File tree

3 files changed

+29
-3
lines changed

3 files changed

+29
-3
lines changed

support/defaults.go

+2
Original file line numberDiff line numberDiff line change
@@ -10,4 +10,6 @@ const (
1010
RayROCmImage = "quay.io/modh/ray:2.35.0-py311-rocm61"
1111
RayTorchCudaImage = "quay.io/rhoai/ray:2.35.0-py311-cu121-torch24-fa26"
1212
RayTorchROCmImage = "quay.io/rhoai/ray:2.35.0-py311-rocm61-torch24-fa26"
13+
TrainingCudaImage = "quay.io/modh/training:py311-cuda121-torch241"
14+
TrainingROCmImage = "quay.io/modh/training:py311-rocm61-torch241"
1315
)

support/environment.go

+12-3
Original file line numberDiff line numberDiff line change
@@ -25,9 +25,10 @@ const (
2525
// The environment variables hereafter can be used to change the components
2626
// used for testing.
2727

28-
CodeFlareTestRayVersion = "CODEFLARE_TEST_RAY_VERSION"
29-
CodeFlareTestRayImage = "CODEFLARE_TEST_RAY_IMAGE"
30-
CodeFlareTestPyTorchImage = "CODEFLARE_TEST_PYTORCH_IMAGE"
28+
CodeFlareTestRayVersion = "CODEFLARE_TEST_RAY_VERSION"
29+
CodeFlareTestRayImage = "CODEFLARE_TEST_RAY_IMAGE"
30+
CodeFlareTestPyTorchImage = "CODEFLARE_TEST_PYTORCH_IMAGE"
31+
CodeFlareTestTrainingImage = "CODEFLARE_TEST_TRAINING_IMAGE"
3132

3233
// The testing output directory, to write output files into.
3334
CodeFlareTestOutputDir = "CODEFLARE_TEST_OUTPUT_DIR"
@@ -97,6 +98,14 @@ func GetPyTorchImage() string {
9798
return lookupEnvOrDefault(CodeFlareTestPyTorchImage, "pytorch/pytorch:1.11.0-cuda11.3-cudnn8-runtime")
9899
}
99100

101+
func GetCudaTrainingImage() string {
102+
return lookupEnvOrDefault(CodeFlareTestTrainingImage, TrainingCudaImage)
103+
}
104+
105+
func GetROCmTrainingImage() string {
106+
return lookupEnvOrDefault(CodeFlareTestTrainingImage, TrainingROCmImage)
107+
}
108+
100109
func GetInstascaleOcmSecret() (string, string) {
101110
res := strings.SplitN(lookupEnvOrDefault(InstaScaleOcmSecret, "default/instascale-ocm-secret"), "/", 2)
102111
return res[0], res[1]

support/environment_test.go

+15
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,21 @@ func TestGetPyTorchImage(t *testing.T) {
5353

5454
}
5555

56+
func TestGetTrainingImage(t *testing.T) {
57+
58+
g := gomega.NewGomegaWithT(t)
59+
// Set the environment variable.
60+
os.Setenv(CodeFlareTestTrainingImage, "training/training:latest")
61+
62+
// Get the image.
63+
image := GetCudaTrainingImage()
64+
65+
// Assert that the image is correct.
66+
67+
g.Expect(image).To(gomega.Equal("training/training:latest"), "Expected image training/training:latest, but got %s", image)
68+
69+
}
70+
5671
func TestGetClusterID(t *testing.T) {
5772

5873
g := gomega.NewGomegaWithT(t)

0 commit comments

Comments
 (0)