This repository was archived by the owner on Sep 11, 2023. It is now read-only.

Commit 5499861

Merge pull request #453 from openclimatefix/jack/325-use-multiprocessing-pool-in-manager

Use `multiprocessing.Pool` instead of `ProcessPoolExecutor`

2 parents: b04cac1 + ff0f11d
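For context, here is a minimal sketch (not code from this repo; `work` is a hypothetical stand-in for `data_source.create_batches`) contrasting the two submission styles this commit swaps between: `ProcessPoolExecutor.submit()` returns a `Future` whose `.exception()` you poll, while `multiprocessing.Pool.apply_async()` returns an `AsyncResult` and lets you register `callback`/`error_callback` hooks that fire in the parent process.

```python
# Sketch only: contrasts ProcessPoolExecutor.submit() with Pool.apply_async().
import multiprocessing
from concurrent import futures


def work(x: int) -> int:
    # Stand-in for a real task such as data_source.create_batches().
    return x * 2


if __name__ == "__main__":
    # Old style: submit() returns a Future. exception() blocks until the
    # worker finishes and *returns* (does not raise) any worker exception.
    with futures.ProcessPoolExecutor(max_workers=2) as executor:
        future_jobs = [executor.submit(work, x) for x in range(4)]
        for future in future_jobs:
            exc = future.exception()  # waits for the worker
            if exc is not None:
                raise exc

    # New style: apply_async() returns an AsyncResult. The callbacks run
    # in the parent process when the task succeeds or fails.
    with multiprocessing.Pool(processes=2) as pool:
        async_results = [
            pool.apply_async(
                work,
                args=(x,),
                callback=lambda result: print(f"done: {result}"),
                error_callback=lambda exception: print(f"failed: {exception}"),
            )
            for x in range(4)
        ]
        for async_result in async_results:
            async_result.wait()  # blocks; does NOT re-raise worker exceptions
```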

File tree: 1 file changed (+38 -25 lines)

nowcasting_dataset/manager.py (+38 -25)
@@ -1,7 +1,7 @@
 """Manager class."""

 import logging
-from concurrent import futures
+import multiprocessing
 from pathlib import Path
 from typing import Optional, Union

@@ -419,24 +419,21 @@ def create_batches(self, overwrite_batches: bool) -> None:
             locations_for_each_example["t0_datetime_UTC"] = pd.to_datetime(
                 locations_for_each_example["t0_datetime_UTC"]
             )
-            locations_for_each_example_of_each_split[split_name] = locations_for_each_example
+            if len(locations_for_each_example) > 0:
+                locations_for_each_example_of_each_split[split_name] = locations_for_each_example

         # Fire up a separate process for each DataSource, and pass it a list of batches to
         # create, and whether to utils.upload_and_delete_local_files().
         # TODO: Issue 321: Split this up into separate functions!!!
         n_data_sources = len(self.data_sources)
         nd_utils.set_fsspec_for_multiprocess()
-        for split_name in splits_which_need_more_batches:
-            locations_for_split = locations_for_each_example_of_each_split[split_name]
-            with futures.ProcessPoolExecutor(max_workers=n_data_sources) as executor:
-                future_create_batches_jobs = []
+        for split_name, locations_for_split in locations_for_each_example_of_each_split.items():
+            with multiprocessing.Pool(processes=n_data_sources) as pool:
+                async_results_from_create_batches = []
                 for worker_id, (data_source_name, data_source) in enumerate(
                     self.data_sources.items()
                 ):

-                    if len(locations_for_split) == 0:
-                        break
-
                     # Get indexes of first batch and example. And subset locations_for_split.
                     idx_of_first_batch = first_batches_to_create[split_name][data_source_name]
                     idx_of_first_example = idx_of_first_batch * self.config.process.batch_size
@@ -446,6 +443,8 @@ def create_batches(self, overwrite_batches: bool) -> None:
                     dst_path = (
                         self.config.output_data.filepath / split_name.value / data_source_name
                     )
+
+                    # TODO: Issue 455: Guarantee that local temp path is unique and empty.
                     local_temp_path = (
                         self.local_temp_path
                         / split_name.value
@@ -458,27 +457,41 @@ def create_batches(self, overwrite_batches: bool) -> None:
                     if self.save_batches_locally_and_upload:
                         nd_fs_utils.makedirs(local_temp_path, exist_ok=True)

-                    # Submit data_source.create_batches task to the worker process.
-                    future = executor.submit(
-                        data_source.create_batches,
+                    # Keyword arguments to be passed into data_source.create_batches():
+                    kwargs_for_create_batches = dict(
                         spatial_and_temporal_locations_of_each_example=locations,
                         idx_of_first_batch=idx_of_first_batch,
                         batch_size=self.config.process.batch_size,
                         dst_path=dst_path,
                         local_temp_path=local_temp_path,
                         upload_every_n_batches=self.config.process.upload_every_n_batches,
                     )
-                    future_create_batches_jobs.append(future)

-                # Wait for all futures to finish:
-                for future, data_source_name in zip(
-                    future_create_batches_jobs, self.data_sources.keys()
-                ):
-                    # Call exception() to propagate any exceptions raised by the worker process into
-                    # the main process, and to wait for the worker to finish.
-                    exception = future.exception()
-                    if exception is not None:
-                        logger.exception(
-                            f"Worker process {data_source_name} raised exception!\n{exception}"
-                        )
-                        raise exception
+                    # Logger messages for callbacks:
+                    callback_msg = (
+                        f"{data_source_name} has finished creating batches for {split_name}!"
+                    )
+                    error_callback_msg = (
+                        f"Exception raised by {data_source_name} whilst creating batches for"
+                        f" {split_name}:\n"
+                    )
+
+                    # Submit data_source.create_batches task to the worker process.
+                    logger.debug(
+                        f"About to submit create_batches task for {data_source_name}, {split_name}"
+                    )
+                    async_result = pool.apply_async(
+                        data_source.create_batches,
+                        kwds=kwargs_for_create_batches,
+                        callback=lambda result: logger.info(callback_msg),
+                        error_callback=lambda exception: logger.error(
+                            error_callback_msg + str(exception)
+                        ),
+                    )
+                    async_results_from_create_batches.append(async_result)

+                # Wait for all async_results to finish:
+                for async_result in async_results_from_create_batches:
+                    async_result.wait()
+
+                logger.info(f"Finished creating batches for {split_name}!")
