Skip to content

Commit dba361b

Browse files
Added parameters for S3 compatible storage bucket from which to download MNIST datasets
1 parent 22087f8 commit dba361b

File tree

1 file changed

+39
-1
lines changed

1 file changed

+39
-1
lines changed

support/environment.go

+39-1
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@ const (
2929
CodeFlareTestRayVersion = "CODEFLARE_TEST_RAY_VERSION"
3030
CodeFlareTestRayImage = "CODEFLARE_TEST_RAY_IMAGE"
3131
CodeFlareTestPyTorchImage = "CODEFLARE_TEST_PYTORCH_IMAGE"
32+
NotebookImage = "NOTEBOOK_IMAGE"
3233

3334
// The testing output directory, to write output files into.
3435
CodeFlareTestOutputDir = "CODEFLARE_TEST_OUTPUT_DIR"
@@ -51,6 +52,13 @@ const (
5152
// URL for PiPI index containing all the required test Python packages
5253
pipIndexURL = "PIP_INDEX_URL"
5354
pipTrustedHost = "PIP_TRUSTED_HOST"
55+
56+
// Storage bucket credentials
57+
storageDefaultEndpoint = "AWS_DEFAULT_ENDPOINT"
58+
storageAccessKeyId = "AWS_ACCESS_KEY_ID"
59+
storageSecretKey = "AWS_SECRET_ACCESS_KEY"
60+
storageBucketName = "AWS_STORAGE_BUCKET"
61+
storageBucketMnistDir = "AWS_STORAGE_BUCKET_MNIST_DIR"
5462
)
5563

5664
type ClusterType string
@@ -79,6 +87,11 @@ func GetPyTorchImage() string {
7987
return lookupEnvOrDefault(CodeFlareTestPyTorchImage, "pytorch/pytorch:1.11.0-cuda11.3-cudnn8-runtime")
8088
}
8189

90+
func GetNotebookImage() string {
91+
notebook_image, _ := os.LookupEnv(NotebookImage)
92+
return notebook_image
93+
}
94+
8295
func GetInstascaleOcmSecret() (string, string) {
8396
res := strings.SplitN(lookupEnvOrDefault(InstaScaleOcmSecret, "default/instascale-ocm-secret"), "/", 2)
8497
return res[0], res[1]
@@ -118,7 +131,32 @@ func GetClusterHostname(t Test) string {
118131
}
119132

120133
func GetMnistDatasetURL() string {
121-
return lookupEnvOrDefault(mnistDatasetURL, "http://yann.lecun.com/exdb/mnist/")
134+
return lookupEnvOrDefault(mnistDatasetURL, "https://ossci-datasets.s3.amazonaws.com/mnist/")
135+
}
136+
137+
func GetStorageBucketDefaultEndpoint() (string, bool) {
138+
storage_endpoint, exists := os.LookupEnv(storageDefaultEndpoint)
139+
return storage_endpoint, exists
140+
}
141+
142+
func GetStorageBucketAccessKeyId() (string, bool) {
143+
storage_access_key_id, exists := os.LookupEnv(storageAccessKeyId)
144+
return storage_access_key_id, exists
145+
}
146+
147+
func GetStorageBucketSecretKey() (string, bool) {
148+
storage_secret_key, exists := os.LookupEnv(storageSecretKey)
149+
return storage_secret_key, exists
150+
}
151+
152+
func GetStorageBucketname() (string, bool) {
153+
storage_bucket_name, exists := os.LookupEnv(storageBucketName)
154+
return storage_bucket_name, exists
155+
}
156+
157+
func GetStorageBucketMnistDir() (string, bool) {
158+
storage_bucket_mnist_dir, exists := os.LookupEnv(storageBucketMnistDir)
159+
return storage_bucket_mnist_dir, exists
122160
}
123161

124162
func GetPipIndexURL() string {

0 commit comments

Comments
 (0)