Skip to content
This repository was archived by the owner on Sep 11, 2023. It is now read-only.

Commit 8bef05c

Browse files
committed
Check there is no overlap between split datetimes. Fixes #299
1 parent 8d5043b commit 8bef05c

File tree

3 files changed

+23
-14
lines changed

3 files changed

+23
-14
lines changed

nowcasting_dataset/config/model.py

+9-1
Original file line numberDiff line numberDiff line change
@@ -19,12 +19,14 @@
1919
from pydantic import BaseModel, Field
2020
from pydantic import validator, root_validator
2121

22-
from nowcasting_dataset.consts import NWP_VARIABLE_NAMES
22+
# nowcasting_dataset imports
2323
from nowcasting_dataset.consts import (
24+
NWP_VARIABLE_NAMES,
2425
SAT_VARIABLE_NAMES,
2526
DEFAULT_N_GSP_PER_EXAMPLE,
2627
DEFAULT_N_PV_SYSTEMS_PER_EXAMPLE,
2728
)
29+
from nowcasting_dataset.dataset.split import split
2830

2931

3032
IMAGE_SIZE_PIXELS_FIELD = Field(64, description="The number of pixels of the region of interest.")
@@ -278,6 +280,12 @@ class Process(BaseModel):
278280
" return data at 11:30, 12:00, 12:30, and 13:00."
279281
),
280282
)
283+
split_method: split.SplitMethod = Field(
284+
split.SplitMethod.DAY,
285+
description=(
286+
"The method used to split the t0 datetimes into train, validation and test sets."
287+
),
288+
)
281289
upload_every_n_batches: int = Field(
282290
16,
283291
description=(

nowcasting_dataset/dataset/split/split.py

+6
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@ class SplitName(Enum):
3939
TEST = "test"
4040

4141

42+
# Create a namedtuple for storing split t0 datetimes.
4243
SplitData = namedtuple(
4344
typename="SplitData",
4445
field_names=[SplitName.TRAIN.value, SplitName.VALIDATION.value, SplitName.TEST.value],
@@ -186,4 +187,9 @@ def split_data(
186187
f"test has {len(test_datetimes):,d} t0 datetimes."
187188
)
188189

190+
# Check there's no overlap.
191+
assert len(train_datetimes.intersection(validation_datetimes)) == 0
192+
assert len(train_datetimes.intersection(test_datetimes)) == 0
193+
assert len(test_datetimes.intersection(validation_datetimes)) == 0
194+
189195
return SplitData(train=train_datetimes, validation=validation_datetimes, test=test_datetimes)

nowcasting_dataset/manager.py

+8-13
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
from nowcasting_dataset import config
1212
from nowcasting_dataset.filesystem import utils as nd_fs_utils
1313
from nowcasting_dataset.data_sources import MAP_DATA_SOURCE_NAME_TO_CLASS, ALL_DATA_SOURCE_NAMES
14+
from nowcasting_dataset.dataset.split import split
1415

1516
logger = logging.getLogger(__name__)
1617

@@ -116,6 +117,9 @@ def create_files_specifying_spatial_and_temporal_locations_of_each_example_if_ne
116117
t0_datetimes = self.get_t0_datetimes_across_all_data_sources(
117118
freq=self.config.process.t0_datetime_frequency
118119
)
120+
split_t0_datetimes = split.split_data(
121+
datetimes=t0_datetimes, method=self.config.process.split_method
122+
)
119123

120124
def _locations_csv_file_exists(self) -> bool:
121125
"Check if filepath/train/spatial_and_temporal_locations_of_each_example.csv exists"
@@ -202,20 +206,11 @@ def sample_spatial_and_temporal_locations_for_examples(
202206
Each row of each the DataFrame specifies the position of each example, using
203207
columns: 't0_datetime_UTC', 'x_center_OSGB', 'y_center_OSGB'.
204208
"""
205-
# This code is for backwards-compatibility with code which expects the first DataSource
206-
# in the list to be used to define which DataSource defines the spatial location.
207-
# TODO: Remove this try block after implementing issue #213.
208-
try:
209-
data_source_which_defines_geospatial_locations = (
210-
self.data_source_which_defines_geospatial_locations
211-
)
212-
except AttributeError:
213-
data_source_which_defines_geospatial_locations = self[0]
214-
215209
shuffled_t0_datetimes = np.random.choice(t0_datetimes, size=n_examples)
216-
x_locations, y_locations = data_source_which_defines_geospatial_locations.get_locations(
217-
shuffled_t0_datetimes
218-
)
210+
(
211+
x_locations,
212+
y_locations,
213+
) = self.data_source_which_defines_geospatial_locations.get_locations(shuffled_t0_datetimes)
219214
return pd.DataFrame(
220215
{
221216
"t0_datetime_UTC": shuffled_t0_datetimes,

0 commit comments

Comments
 (0)