Skip to content

Commit 87baca6

Browse files
committed
Add integ tests
1 parent f70655a commit 87baca6

File tree

3 files changed

+86
-3
lines changed

3 files changed

+86
-3
lines changed

integ/test_codegen.py

+7
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@
4242
iris_df = iris_df[["target"] + [col for col in iris_df.columns if col != "target"]]
4343
train_data, test_data = train_test_split(iris_df, test_size=0.2, random_state=42)
4444
train_data.to_csv("./data/train.csv", index=False, header=False)
45+
test_data_no_target = test_data.drop('target', axis=1)
4546

4647
# Upload Data
4748
prefix = "DEMO-scikit-iris"
@@ -149,6 +150,12 @@ def test_training_and_inference(self):
149150
)
150151
endpoint.wait_for_status("InService")
151152

153+
invoke_result = endpoint.invoke(body=test_data_no_target.to_csv(header=False, index=False),
154+
content_type='text/csv',
155+
accept='text/csv')
156+
157+
print(invoke_result)
158+
152159
def test_intelligent_defaults(self):
153160
os.environ["SAGEMAKER_CORE_ADMIN_CONFIG_OVERRIDE"] = (
154161
self._setup_intelligent_default_configs_and_fetch_path()

integ/test_experiment_and_trial.py

+73
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,73 @@
1+
import datetime
2+
import time
3+
import unittest
4+
5+
from sagemaker_core.helper.session_helper import Session, get_execution_role
6+
from sagemaker_core.main.resources import Experiment, Trial, TrialComponent
7+
from sagemaker_core.main.shapes import RawMetricData, TrialComponentParameterValue
8+
from sagemaker_core.main.utils import get_textual_rich_logger
9+
10+
logger = get_textual_rich_logger(__name__)
11+
12+
sagemaker_session = Session()
13+
region = sagemaker_session.boto_region_name
14+
role = get_execution_role()
15+
bucket = sagemaker_session.default_bucket()
16+
17+
18+
class TestExperimentAndTrial(unittest.TestCase):
19+
def test_experiment_and_trial(self):
20+
experiment_name = "local-pyspark-experiment-example-" + time.strftime(
21+
"%Y-%m-%d-%H-%M-%S", time.gmtime()
22+
)
23+
run_group_name = "Default-Run-Group-" + experiment_name
24+
run_name = "local-experiment-run-" + time.strftime("%Y-%m-%d-%H-%M-%S", time.gmtime())
25+
26+
experiment = Experiment.create(experiment_name=experiment_name)
27+
trial = Trial.create(trial_name=run_group_name, experiment_name=experiment_name)
28+
29+
created_after = datetime.datetime.now() - datetime.timedelta(days=5)
30+
experiments_iterator = Experiment.get_all(created_after=created_after)
31+
experiments = [exp.experiment_name for exp in experiments_iterator]
32+
33+
assert len(experiments) > 0
34+
assert experiment.experiment_name in experiments
35+
36+
trial_component_parameters = {
37+
"num_train_samples": TrialComponentParameterValue(number_value=5),
38+
"num_test_samples": TrialComponentParameterValue(number_value=5),
39+
}
40+
41+
trial_component = TrialComponent.create(
42+
trial_component_name=run_name,
43+
parameters=trial_component_parameters,
44+
)
45+
trial_component.associate_trail(trial_name=trial.trial_name)
46+
47+
training_parameters = {
48+
"device": TrialComponentParameterValue(string_value="cpu"),
49+
"data_dir": TrialComponentParameterValue(string_value="test"),
50+
"optimizer": TrialComponentParameterValue(string_value="sgd"),
51+
"epochs": TrialComponentParameterValue(number_value=5),
52+
"hidden_channels": TrialComponentParameterValue(number_value=10),
53+
}
54+
trial_component.update(parameters=training_parameters)
55+
56+
metrics = []
57+
for i in range(5):
58+
accuracy_metric = RawMetricData(
59+
metric_name="test:accuracy",
60+
value=i / 10,
61+
step=i,
62+
timestamp=time.time(),
63+
)
64+
metrics.append(accuracy_metric)
65+
66+
trial_component.batch_put_metrics(metric_data=metrics)
67+
68+
time.sleep(10)
69+
trial_component.refresh()
70+
71+
assert len(trial_component.parameters) == 7
72+
assert len(trial_component.metrics) == 1
73+
assert trial_component.metrics[0].count == 5

src/sagemaker_core/main/utils.py

+6-3
Original file line numberDiff line numberDiff line change
@@ -500,7 +500,8 @@ def _serialize_dict(value: Dict) -> dict:
500500
"""
501501
serialized_dict = {}
502502
for k, v in value.items():
503-
if serialize_result := serialize(v):
503+
serialize_result = serialize(v)
504+
if serialize_result is not None:
504505
serialized_dict.update({k: serialize_result})
505506
return serialized_dict
506507

@@ -517,7 +518,8 @@ def _serialize_list(value: List) -> list:
517518
"""
518519
serialized_list = []
519520
for v in value:
520-
if serialize_result := serialize(v):
521+
serialize_result = serialize(v)
522+
if serialize_result is not None:
521523
serialized_list.append(serialize_result)
522524
return serialized_list
523525

@@ -534,7 +536,8 @@ def _serialize_shape(value: Any) -> dict:
534536
"""
535537
serialized_dict = {}
536538
for k, v in vars(value).items():
537-
if serialize_result := serialize(v):
539+
serialize_result = serialize(v)
540+
if serialize_result is not None:
538541
key = snake_to_pascal(k) if is_snake_case(k) else k
539542
serialized_dict.update({key[0].upper() + key[1:]: serialize_result})
540543
return serialized_dict

0 commit comments

Comments
 (0)