@@ -118,7 +118,8 @@ def count_call_alleles(
118
118
n_alleles = ds .dims ["alleles" ]
119
119
G = da .asarray (ds [call_genotype ])
120
120
shape = (G .chunks [0 ], G .chunks [1 ], n_alleles )
121
- N = da .empty (n_alleles , dtype = np .uint8 )
121
+ # use numpy array to avoid dask task dependencies between chunks
122
+ N = np .empty (n_alleles , dtype = np .uint8 )
122
123
new_ds = create_dataset (
123
124
{
124
125
variables .call_allele_count : (
@@ -263,8 +264,10 @@ def count_cohort_alleles(
263
264
ds , variables .call_allele_count , call_allele_count , count_call_alleles
264
265
)
265
266
variables .validate (ds , {call_allele_count : variables .call_allele_count_spec })
266
- AC , SC = da .asarray (ds [call_allele_count ]), da .asarray (ds [sample_cohort ])
267
- n_cohorts = SC .max ().compute () + 1 # 0-based indexing
267
+ # ensure cohorts is a numpy array to minimize dask task
268
+ # dependencies between chunks in other dimensions
269
+ AC , SC = da .asarray (ds [call_allele_count ]), ds [sample_cohort ].values
270
+ n_cohorts = SC .max () + 1 # 0-based indexing
268
271
AC = cohort_sum (AC , SC , n_cohorts , axis = 1 )
269
272
new_ds = create_dataset (
270
273
{variables .cohort_allele_count : (("variants" , "cohorts" , "alleles" ), AC )}
0 commit comments