@@ -1,7 +1,7 @@
 """Manager class."""
 
 import logging
-from concurrent import futures
+import multiprocessing
 from pathlib import Path
 from typing import Optional, Union
 
@@ -419,24 +419,21 @@ def create_batches(self, overwrite_batches: bool) -> None:
             locations_for_each_example["t0_datetime_UTC"] = pd.to_datetime(
                 locations_for_each_example["t0_datetime_UTC"]
             )
-            locations_for_each_example_of_each_split[split_name] = locations_for_each_example
+            if len(locations_for_each_example) > 0:
+                locations_for_each_example_of_each_split[split_name] = locations_for_each_example
 
         # Fire up a separate process for each DataSource, and pass it a list of batches to
         # create, and whether to utils.upload_and_delete_local_files().
         # TODO: Issue 321: Split this up into separate functions!!!
         n_data_sources = len(self.data_sources)
         nd_utils.set_fsspec_for_multiprocess()
-        for split_name in splits_which_need_more_batches:
-            locations_for_split = locations_for_each_example_of_each_split[split_name]
-            with futures.ProcessPoolExecutor(max_workers=n_data_sources) as executor:
-                future_create_batches_jobs = []
+        for split_name, locations_for_split in locations_for_each_example_of_each_split.items():
+            with multiprocessing.Pool(processes=n_data_sources) as pool:
+                async_results_from_create_batches = []
                 for worker_id, (data_source_name, data_source) in enumerate(
                     self.data_sources.items()
                 ):
 
-                    if len(locations_for_split) == 0:
-                        break
-
                     # Get indexes of first batch and example. And subset locations_for_split.
                     idx_of_first_batch = first_batches_to_create[split_name][data_source_name]
                     idx_of_first_example = idx_of_first_batch * self.config.process.batch_size
@@ -446,6 +443,8 @@ def create_batches(self, overwrite_batches: bool) -> None:
                     dst_path = (
                         self.config.output_data.filepath / split_name.value / data_source_name
                     )
+
+                    # TODO: Issue 455: Guarantee that local temp path is unique and empty.
                     local_temp_path = (
                         self.local_temp_path
                         / split_name.value
@@ -458,27 +457,41 @@ def create_batches(self, overwrite_batches: bool) -> None:
                     if self.save_batches_locally_and_upload:
                         nd_fs_utils.makedirs(local_temp_path, exist_ok=True)
 
-                    # Submit data_source.create_batches task to the worker process.
-                    future = executor.submit(
-                        data_source.create_batches,
+                    # Keyword arguments to be passed into data_source.create_batches():
+                    kwargs_for_create_batches = dict(
                         spatial_and_temporal_locations_of_each_example=locations,
                         idx_of_first_batch=idx_of_first_batch,
                         batch_size=self.config.process.batch_size,
                         dst_path=dst_path,
                         local_temp_path=local_temp_path,
                         upload_every_n_batches=self.config.process.upload_every_n_batches,
                     )
-                    future_create_batches_jobs.append(future)
 
-                # Wait for all futures to finish:
-                for future, data_source_name in zip(
-                    future_create_batches_jobs, self.data_sources.keys()
-                ):
-                    # Call exception() to propagate any exceptions raised by the worker process into
-                    # the main process, and to wait for the worker to finish.
-                    exception = future.exception()
-                    if exception is not None:
-                        logger.exception(
-                            f"Worker process {data_source_name} raised exception!\n{exception}"
-                        )
-                        raise exception
+                    # Logger messages for callbacks:
+                    callback_msg = (
+                        f"{data_source_name} has finished creating batches for {split_name}!"
+                    )
+                    error_callback_msg = (
+                        f"Exception raised by {data_source_name} whilst creating batches for"
+                        f" {split_name}:\n"
+                    )
+
+                    # Submit data_source.create_batches task to the worker process.
+                    logger.debug(
+                        f"About to submit create_batches task for {data_source_name}, {split_name}"
+                    )
+                    # Bind the current messages via default arguments, so each callback
+                    # logs its own data source rather than the last one submitted.
+                    async_result = pool.apply_async(
+                        data_source.create_batches,
+                        kwds=kwargs_for_create_batches,
+                        callback=lambda result, msg=callback_msg: logger.info(msg),
+                        error_callback=lambda exception, msg=error_callback_msg: logger.error(
+                            msg + str(exception)
+                        ),
+                    )
+                    async_results_from_create_batches.append(async_result)
+
+                # Wait for all async_results to finish:
+                for async_result in async_results_from_create_batches:
+                    async_result.wait()
+
+                logger.info(f"Finished creating batches for {split_name}!")
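
For context: the change above swaps `concurrent.futures.ProcessPoolExecutor` for `multiprocessing.Pool`, submitting each `data_source.create_batches` call with `apply_async` and reporting completion through `callback`/`error_callback` hooks. Below is a minimal, self-contained sketch of that pattern; the `create_batches` stand-in and the data-source names are hypothetical, not taken from this repo.

```python
import logging
import multiprocessing

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


def create_batches(name: str) -> str:
    """Hypothetical stand-in for DataSource.create_batches()."""
    if name == "bad":
        raise ValueError("simulated failure")
    return name


if __name__ == "__main__":
    with multiprocessing.Pool(processes=2) as pool:
        async_results = []
        for name in ("sat", "nwp", "bad"):
            # Bind the current value of `name` via a default argument. A bare
            # `lambda result: logger.info(f"{name} finished!")` would close over
            # the loop variable and report the last name in every callback.
            async_result = pool.apply_async(
                create_batches,
                kwds=dict(name=name),
                callback=lambda result, n=name: logger.info(f"{n} finished!"),
                error_callback=lambda exc, n=name: logger.error(f"{n} raised: {exc}"),
            )
            async_results.append(async_result)

        # wait() only blocks until the task completes; unlike Future.exception(),
        # it does NOT re-raise worker exceptions in the parent process.
        # Call get() instead if exceptions should propagate.
        for async_result in async_results:
            async_result.wait()
```

One behavioural difference worth noting: the removed code called `future.exception()`, which propagates worker exceptions into the main process, whereas `AsyncResult.wait()` only blocks, so a failed worker now surfaces only through `error_callback` (or through a later `async_result.get()`).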