|
11 | 11 | from nowcasting_dataset import config
|
12 | 12 | from nowcasting_dataset.filesystem import utils as nd_fs_utils
|
13 | 13 | from nowcasting_dataset.data_sources import MAP_DATA_SOURCE_NAME_TO_CLASS, ALL_DATA_SOURCE_NAMES
|
| 14 | +from nowcasting_dataset.dataset.split import split |
14 | 15 |
|
15 | 16 | logger = logging.getLogger(__name__)
|
16 | 17 |
|
@@ -116,6 +117,9 @@ def create_files_specifying_spatial_and_temporal_locations_of_each_example_if_ne
|
116 | 117 | t0_datetimes = self.get_t0_datetimes_across_all_data_sources(
|
117 | 118 | freq=self.config.process.t0_datetime_frequency
|
118 | 119 | )
|
| 120 | + split_t0_datetimes = split.split_data( |
| 121 | + datetimes=t0_datetimes, method=self.config.process.split_method |
| 122 | + ) |
119 | 123 |
|
120 | 124 | def _locations_csv_file_exists(self) -> bool:
|
121 | 125 | "Check if filepath/train/spatial_and_temporal_locations_of_each_example.csv exists"
|
@@ -202,20 +206,11 @@ def sample_spatial_and_temporal_locations_for_examples(
|
202 | 206 | Each row of the DataFrame specifies the position of each example, using
|
203 | 207 | columns: 't0_datetime_UTC', 'x_center_OSGB', 'y_center_OSGB'.
|
204 | 208 | """
|
205 |
| - # This code is for backwards-compatibility with code which expects the first DataSource |
206 |
| - # in the list to be used to define which DataSource defines the spatial location. |
207 |
| - # TODO: Remove this try block after implementing issue #213. |
208 |
| - try: |
209 |
| - data_source_which_defines_geospatial_locations = ( |
210 |
| - self.data_source_which_defines_geospatial_locations |
211 |
| - ) |
212 |
| - except AttributeError: |
213 |
| - data_source_which_defines_geospatial_locations = self[0] |
214 |
| - |
215 | 209 | shuffled_t0_datetimes = np.random.choice(t0_datetimes, size=n_examples)
|
216 |
| - x_locations, y_locations = data_source_which_defines_geospatial_locations.get_locations( |
217 |
| - shuffled_t0_datetimes |
218 |
| - ) |
| 210 | + ( |
| 211 | + x_locations, |
| 212 | + y_locations, |
| 213 | + ) = self.data_source_which_defines_geospatial_locations.get_locations(shuffled_t0_datetimes) |
219 | 214 | return pd.DataFrame(
|
220 | 215 | {
|
221 | 216 | "t0_datetime_UTC": shuffled_t0_datetimes,
|
|
0 commit comments