Skip to content
This repository was archived by the owner on Sep 11, 2023. It is now read-only.

Commit ce6e5ba

Browse files
committed
check columns names
1 parent 51f3640 commit ce6e5ba

File tree

3 files changed

+11
-5
lines changed

3 files changed

+11
-5
lines changed

nowcasting_dataset/data_sources/data_source.py

+1 -1
Original file line number | Diff line number | Diff line change
@@ -167,7 +167,7 @@ def create_batches(
167167
assert batch_size > 0
168168
assert len(spatial_and_temporal_locations_of_each_example) % batch_size == 0
169169
assert upload_every_n_batches >= 0
170-
assert spatial_and_temporal_locations_of_each_example.columns.to_tuple() == (
170+
assert spatial_and_temporal_locations_of_each_example.columns.to_list() == list(
171171
SPATIAL_AND_TEMPORAL_LOCATIONS_COLUMN_NAMES
172172
)
173173

nowcasting_dataset/manager.py

+9 -4
Original file line number | Diff line number | Diff line change
@@ -342,8 +342,8 @@ def create_batches(self, overwrite_batches: bool) -> None:
342342
for split_name in splits_which_need_more_batches:
343343
filename = self._filename_of_locations_csv_file(split_name.value)
344344
logger.info(f"Loading {filename}.")
345-
locations_for_each_example = pd.read_csv(filename)
346-
assert locations_for_each_example.columns.to_tuple() == (
345+
locations_for_each_example = pd.read_csv(filename, index_col=0)
346+
assert locations_for_each_example.columns.to_list() == list(
347347
SPATIAL_AND_TEMPORAL_LOCATIONS_COLUMN_NAMES
348348
)
349349
# Converting to datetimes is much faster using `pd.to_datetime()` than
@@ -399,7 +399,12 @@ def create_batches(self, overwrite_batches: bool) -> None:
399399
future_create_batches_jobs.append(future)
400400

401401
# Wait for all futures to finish:
402-
for future in future_create_batches_jobs:
402+
for future, data_source_name in zip(
403+
future_create_batches_jobs, self.data_sources.keys()
404+
):
403405
# Call exception() to propagate any exceptions raised by the worker process into
404406
# the main process, and to wait for the worker to finish.
405-
future.exception()
407+
exception = future.exception()
408+
if exception is not None:
409+
logger.exception(f"Worker process {data_source_name} raised exception!")
410+
raise exception

scripts/prepare_ml_data.py

+1
Original file line number | Diff line number | Diff line change
@@ -68,6 +68,7 @@ def main(config_filename: str, data_source: list[str], overwrite_batches: bool):
6868
manager.create_batches(overwrite_batches)
6969
# TODO: Issue #316: save_yaml_configuration(config)
7070
# TODO: Issue #317: Validate ML data.
71+
logger.info("Done!")
7172

7273

7374
if __name__ == "__main__":

0 commit comments

Comments
 (0)