
Commit 64638f5

Toggle numba caching by environment variable
1 parent c13acb1 commit 64638f5

8 files changed: +51 -32 lines changed

sgkit/caching.py (+10)

@@ -0,0 +1,10 @@
+import os
+
+_DISABLE_CACHE = os.environ.get("SGKIT_DISABLE_NUMBA_CACHE", "0")
+
+try:
+    CACHE_NUMBA = {"0": True, "1": False}[_DISABLE_CACHE]
+except KeyError as e:  # pragma: no cover
+    raise KeyError(
+        "Environment variable 'SGKIT_DISABLE_NUMBA_CACHE' must be '0' or '1'"
+    ) from e
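
The new module reads SGKIT_DISABLE_NUMBA_CACHE once, at import time, mapping "0" to cache=True (the default) and "1" to cache=False for every jitted function changed below. A minimal usage sketch, assuming the variable is set before sgkit is first imported (exporting it in the shell before launching Python behaves the same way):

    import os

    # The toggle is read at module import time (see sgkit/caching.py above),
    # so it must be set before sgkit is imported for the first time.
    os.environ["SGKIT_DISABLE_NUMBA_CACHE"] = "1"  # "1" disables numba's on-disk cache

    import sgkit  # noqa: E402  # jitted functions now compile with cache=False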

sgkit/stats/aggregation.py (+2 -1)

@@ -8,6 +8,7 @@
 from xarray import Dataset

 from sgkit import variables
+from sgkit.caching import CACHE_NUMBA
 from sgkit.stats.utils import cohort_sum
 from sgkit.typing import ArrayLike
 from sgkit.utils import (
@@ -28,7 +29,7 @@
     ],
     "(k),(n)->(n)",
     nopython=True,
-    cache=True,
+    cache=CACHE_NUMBA,
 )
 def count_alleles(
     g: ArrayLike, _: ArrayLike, out: ArrayLike

sgkit/stats/conversion.py (+2 -1)

@@ -4,6 +4,7 @@
 from xarray import Dataset

 from sgkit import variables
+from sgkit.caching import CACHE_NUMBA
 from sgkit.typing import ArrayLike
 from sgkit.utils import conditional_merge_datasets, create_dataset

@@ -15,7 +16,7 @@
     ],
     "(p),(k),()->(k)",
     nopython=True,
-    cache=True,
+    cache=CACHE_NUMBA,
 )
 def _convert_probability_to_call(
     gp: ArrayLike, _: ArrayLike, threshold: float, out: ArrayLike

sgkit/stats/ld.py (+5 -4)

@@ -11,11 +11,12 @@
 from xarray import Dataset

 from sgkit import variables
+from sgkit.caching import CACHE_NUMBA
 from sgkit.typing import ArrayLike, DType
 from sgkit.window import _get_chunked_windows, _sizes_to_start_offsets, has_windows


-@njit(nogil=True, fastmath=False, cache=True)  # type: ignore
+@njit(nogil=True, fastmath=False, cache=CACHE_NUMBA)  # type: ignore
 def rogers_huff_r_between(gn0: ArrayLike, gn1: ArrayLike) -> float:  # pragma: no cover
     """Rogers Huff *r*.

@@ -67,7 +68,7 @@ def rogers_huff_r_between(gn0: ArrayLike, gn1: ArrayLike) -> float:  # pragma: n
     return r


-@njit(nogil=True, fastmath=True, cache=True)  # type: ignore
+@njit(nogil=True, fastmath=True, cache=CACHE_NUMBA)  # type: ignore
 def rogers_huff_r2_between(gn0: ArrayLike, gn1: ArrayLike) -> float:  # pragma: no cover
     return rogers_huff_r_between(gn0, gn1) ** 2  # type: ignore

@@ -202,7 +203,7 @@ def to_ld_df(x: ArrayLike, chunk_index: int) -> DataFrame:
     )


-@njit(nogil=True, cache=True)  # type: ignore
+@njit(nogil=True, cache=CACHE_NUMBA)  # type: ignore
 def _ld_matrix_jit(
     x: ArrayLike,
     chunk_window_starts: ArrayLike,
@@ -302,7 +303,7 @@ def _ld_matrix(
     return df


-@njit(nogil=True, cache=True)  # type: ignore
+@njit(nogil=True, cache=CACHE_NUMBA)  # type: ignore
 def _maximal_independent_set_jit(
     idi: ArrayLike, idj: ArrayLike, cmp: ArrayLike
 ) -> List[int]:  # pragma: no cover

sgkit/stats/pedigree.py (+17 -16)

@@ -7,6 +7,7 @@
 from xarray import Dataset

 from sgkit import variables
+from sgkit.caching import CACHE_NUMBA
 from sgkit.typing import ArrayLike
 from sgkit.utils import (
     conditional_merge_datasets,
@@ -108,7 +109,7 @@ def parent_indices(
     return conditional_merge_datasets(ds, new_ds, merge)


-@njit(cache=True)
+@njit(cache=CACHE_NUMBA)
 def topological_argsort(parent: ArrayLike) -> ArrayLike:  # pragma: no cover
     """Find a topological ordering of samples within a pedigree such
     that no individual occurs before its parents.
@@ -172,7 +173,7 @@ def topological_argsort(parent: ArrayLike) -> ArrayLike:  # pragma: no cover
     return order[::-1]


-@njit(cache=True)
+@njit(cache=CACHE_NUMBA)
 def _is_pedigree_sorted(parent: ArrayLike) -> bool:  # pragma: no cover
     n_samples, n_parents = parent.shape
     for i in range(n_samples):
@@ -183,7 +184,7 @@ def _is_pedigree_sorted(parent: ArrayLike) -> bool:  # pragma: no cover
     return True


-@njit(cache=True)
+@njit(cache=CACHE_NUMBA)
 def _raise_on_half_founder(
     parent: ArrayLike, tau: ArrayLike = None
 ) -> None:  # pragma: no cover
@@ -202,7 +203,7 @@ def _raise_on_half_founder(
            raise ValueError("Pedigree contains half-founders")


-@njit(cache=True)
+@njit(cache=CACHE_NUMBA)
 def _diploid_self_kinship(
     kinship: ArrayLike, parent: ArrayLike, i: int
 ) -> None:  # pragma: no cover
@@ -214,7 +215,7 @@ def _diploid_self_kinship(
        kinship[i, i] = (1 + kinship[p, q]) / 2


-@njit(cache=True)
+@njit(cache=CACHE_NUMBA)
 def _diploid_pair_kinship(
     kinship: ArrayLike, parent: ArrayLike, i: int, j: int
 ) -> None:  # pragma: no cover
@@ -227,7 +228,7 @@ def _diploid_pair_kinship(
     kinship[j, i] = kinship_ij


-@njit(cache=True)
+@njit(cache=CACHE_NUMBA)
 def kinship_diploid(
     parent: ArrayLike, allow_half_founders: bool = False, dtype: type = np.float64
 ) -> ArrayLike:  # pragma: no cover
@@ -290,15 +291,15 @@ def kinship_diploid(
     return kinship


-@njit(cache=True)
+@njit(cache=CACHE_NUMBA)
 def _inbreeding_as_self_kinship(
     inbreeding: float, ploidy: int
 ) -> float:  # pragma: no cover
     """Calculate self-kinship of an individual."""
     return (1 + (ploidy - 1) * inbreeding) / ploidy


-@njit(cache=True)
+@njit(cache=CACHE_NUMBA)
 def _hamilton_kerr_inbreeding_founder(
     lambda_p: float, lambda_q: float, ploidy_i: int
 ) -> float:  # pragma: no cover
@@ -310,7 +311,7 @@ def _hamilton_kerr_inbreeding_founder(
     return num / denom


-@njit(cache=True)
+@njit(cache=CACHE_NUMBA)
 def _hamilton_kerr_inbreeding_non_founder(
     tau_p: int,
     lambda_p: float,
@@ -340,7 +341,7 @@ def _hamilton_kerr_inbreeding_non_founder(
     return num / denom


-@njit(cache=True)
+@njit(cache=CACHE_NUMBA)
 def _hamilton_kerr_inbreeding_half_founder(
     tau_p: int,
     lambda_p: float,
@@ -374,7 +375,7 @@ def _hamilton_kerr_inbreeding_half_founder(
     )


-@njit(cache=True)
+@njit(cache=CACHE_NUMBA)
 def _hamilton_kerr_self_kinship(
     kinship: ArrayLike, parent: ArrayLike, tau: ArrayLike, lambda_: ArrayLike, i: int
 ) -> None:  # pragma: no cover
@@ -421,7 +422,7 @@ def _hamilton_kerr_self_kinship(
     kinship[i, i] = _inbreeding_as_self_kinship(inbreeding_i, ploidy_i)


-@njit(cache=True)
+@njit(cache=CACHE_NUMBA)
 def _hamilton_kerr_pair_kinship(
     kinship: ArrayLike, parent: ArrayLike, tau: ArrayLike, i: int, j: int
 ) -> None:  # pragma: no cover
@@ -435,7 +436,7 @@ def _hamilton_kerr_pair_kinship(
     kinship[j, i] = kinship_ij


-@njit(cache=True)
+@njit(cache=CACHE_NUMBA)
 def kinship_Hamilton_Kerr(
     parent: ArrayLike,
     tau: ArrayLike,
@@ -646,7 +647,7 @@ def pedigree_kinship(
     return conditional_merge_datasets(ds, new_ds, merge)


-@vectorize(nopython=True, cache=True)
+@vectorize(nopython=True, cache=CACHE_NUMBA)
 def kinship_as_additive_relationship(
     kinship: float, ploidy_x: int, ploidy_y: int
 ) -> float:  # pragma: no cover
@@ -783,7 +784,7 @@ def additive_relationships(
     return conditional_merge_datasets(ds, new_ds, merge)


-@njit(cache=True)
+@njit(cache=CACHE_NUMBA)
 def _update_inverse_additive_relationships(
     mtx: ArrayLike,
     kinship: ArrayLike,
@@ -838,7 +839,7 @@ def _update_inverse_additive_relationships(
     mtx[i, i] += scalar / ploidy_i


-@njit(cache=True)
+@njit(cache=CACHE_NUMBA)
 def pedigree_kinships_as_inverse_additive_relationships(
     kinship: ArrayLike, parent: ArrayLike, tau: Union[ArrayLike, None] = None
 ) -> ArrayLike:  # pragma: no cover

sgkit/stats/popgen.py (+6 -5)

@@ -7,6 +7,7 @@
 from numba import guvectorize
 from xarray import Dataset

+from sgkit.caching import CACHE_NUMBA
 from sgkit.cohorts import _cohorts_to_array
 from sgkit.stats.utils import assert_array_shape
 from sgkit.typing import ArrayLike
@@ -137,7 +138,7 @@ def diversity(
     ["void(int64[:, :], float64[:,:])", "void(uint64[:, :], float64[:,:])"],
     "(c, k)->(c,c)",
     nopython=True,
-    cache=True,
+    cache=CACHE_NUMBA,
 )
 def _divergence(ac: ArrayLike, out: ArrayLike) -> None:  # pragma: no cover
     """Generalized U-function for computing divergence.
@@ -310,7 +311,7 @@ def divergence(
     ],
     "(c,c)->(c,c)",
     nopython=True,
-    cache=True,
+    cache=CACHE_NUMBA,
 )
 def _Fst_Hudson(d: ArrayLike, out: ArrayLike) -> None:  # pragma: no cover
     """Generalized U-function for computing Fst using Hudson's estimator.
@@ -342,7 +343,7 @@ def _Fst_Hudson(d: ArrayLike, out: ArrayLike) -> None:  # pragma: no cover
     ],
     "(c,c)->(c,c)",
     nopython=True,
-    cache=True,
+    cache=CACHE_NUMBA,
 )
 def _Fst_Nei(d: ArrayLike, out: ArrayLike) -> None:  # pragma: no cover
     """Generalized U-function for computing Fst using Nei's estimator.
@@ -623,7 +624,7 @@ def Tajimas_D(
     ["void(float32[:, :], float32[:,:,:])", "void(float64[:, :], float64[:,:,:])"],
     "(c,c)->(c,c,c)",
     nopython=True,
-    cache=True,
+    cache=CACHE_NUMBA,
 )
 def _pbs(t: ArrayLike, out: ArrayLike) -> None:  # pragma: no cover
     """Generalized U-function for computing PBS."""
@@ -647,7 +648,7 @@ def _pbs(t: ArrayLike, out: ArrayLike) -> None:  # pragma: no cover
     ],
     "(c,c),(ct,i)->(c,c,c)",
     nopython=True,
-    cache=True,
+    cache=CACHE_NUMBA,
 )
 def _pbs_cohorts(
     t: ArrayLike, ct: ArrayLike, out: ArrayLike

sgkit/stats/utils.py (+6 -4)

@@ -8,6 +8,8 @@
 from numba import guvectorize
 from xarray import DataArray, Dataset

+from sgkit.caching import CACHE_NUMBA
+
 from ..typing import ArrayLike


@@ -176,7 +178,7 @@ def func(x: ArrayLike, cohort: ArrayLike, n: int, axis: int = -1) -> ArrayLike:
     ],
     "(n),(n),(c)->(c)",
     nopython=True,
-    cache=True,
+    cache=CACHE_NUMBA,
 )
 def cohort_sum(
     x: ArrayLike, cohort: ArrayLike, _: ArrayLike, out: ArrayLike
@@ -222,7 +224,7 @@ def cohort_sum(
     ],
     "(n),(n),(c)->(c)",
     nopython=True,
-    cache=True,
+    cache=CACHE_NUMBA,
 )
 def cohort_nansum(
     x: ArrayLike, cohort: ArrayLike, _: ArrayLike, out: ArrayLike
@@ -269,7 +271,7 @@ def cohort_nansum(
     ],
     "(n),(n),(c)->(c)",
     nopython=True,
-    cache=True,
+    cache=CACHE_NUMBA,
 )
 def cohort_mean(
     x: ArrayLike, cohort: ArrayLike, _: ArrayLike, out: ArrayLike
@@ -320,7 +322,7 @@ def cohort_mean(
     ],
     "(n),(n),(c)->(c)",
     nopython=True,
-    cache=True,
+    cache=CACHE_NUMBA,
 )
 def cohort_nanmean(
     x: ArrayLike, cohort: ArrayLike, _: ArrayLike, out: ArrayLike

sgkit/utils.py (+3 -1)

@@ -5,6 +5,8 @@
 from numba import guvectorize
 from xarray import Dataset

+from sgkit.caching import CACHE_NUMBA
+
 from . import variables
 from .typing import ArrayLike, DType

@@ -317,7 +319,7 @@ def max_str_len(a: ArrayLike) -> ArrayLike:
     ],
     "(n)->()",
     nopython=True,
-    cache=True,
+    cache=CACHE_NUMBA,
 )
 def hash_array(x: ArrayLike, out: ArrayLike) -> None:  # pragma: no cover
     """Hash entries of ``x`` using the DJBX33A hash function.
