Skip to content

Commit 11d4427

Browse files
timothymillarmergify[bot]
authored andcommitted
Avoid unnecessary task dependencies by using numpy arrays
1 parent c13acb1 commit 11d4427

File tree

2 files changed

+10
-5
lines changed

2 files changed

+10
-5
lines changed

sgkit/stats/aggregation.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -118,7 +118,8 @@ def count_call_alleles(
118118
n_alleles = ds.dims["alleles"]
119119
G = da.asarray(ds[call_genotype])
120120
shape = (G.chunks[0], G.chunks[1], n_alleles)
121-
N = da.empty(n_alleles, dtype=np.uint8)
121+
# use numpy array to avoid dask task dependencies between chunks
122+
N = np.empty(n_alleles, dtype=np.uint8)
122123
new_ds = create_dataset(
123124
{
124125
variables.call_allele_count: (
@@ -263,8 +264,10 @@ def count_cohort_alleles(
263264
ds, variables.call_allele_count, call_allele_count, count_call_alleles
264265
)
265266
variables.validate(ds, {call_allele_count: variables.call_allele_count_spec})
266-
AC, SC = da.asarray(ds[call_allele_count]), da.asarray(ds[sample_cohort])
267-
n_cohorts = SC.max().compute() + 1 # 0-based indexing
267+
# ensure cohorts is a numpy array to minimize dask task
268+
# dependencies between chunks in other dimensions
269+
AC, SC = da.asarray(ds[call_allele_count]), ds[sample_cohort].values
270+
n_cohorts = SC.max() + 1 # 0-based indexing
268271
AC = cohort_sum(AC, SC, n_cohorts, axis=1)
269272
new_ds = create_dataset(
270273
{variables.cohort_allele_count: (("variants", "cohorts", "alleles"), AC)}

sgkit/stats/popgen.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1018,8 +1018,10 @@ def observed_heterozygosity(
10181018
)
10191019
variables.validate(ds, {call_heterozygosity: variables.call_heterozygosity_spec})
10201020
hi = da.asarray(ds[call_heterozygosity])
1021-
cohort = da.asarray(ds[sample_cohort])
1022-
n_cohorts = cohort.max().compute() + 1
1021+
# ensure cohorts is a numpy array to minimize dask task
1022+
# dependencies between chunks in other dimensions
1023+
cohort = ds[sample_cohort].values
1024+
n_cohorts = cohort.max() + 1
10231025
ho = cohort_nanmean(hi, cohort, n_cohorts)
10241026
if has_windows(ds):
10251027
ho_sum = window_statistic(

0 commit comments

Comments
 (0)