Commit 6b3d495

timothymillar authored and tomwhite committed

Use wrapped versions of jit and guvectorize with custom defaults

1 parent 3438d64 commit 6b3d495

File tree

10 files changed: +82 -108 lines changed

sgkit/accelerate.py (+31)

@@ -0,0 +1,31 @@
+import os
+from typing import Callable
+
+from numba import guvectorize, jit
+
+_DISABLE_CACHE = os.environ.get("SGKIT_DISABLE_NUMBA_CACHE", "0")
+
+try:
+    CACHE_NUMBA = {"0": True, "1": False}[_DISABLE_CACHE]
+except KeyError as e:  # pragma: no cover
+    raise KeyError(
+        "Environment variable 'SGKIT_DISABLE_NUMBA_CACHE' must be '0' or '1'"
+    ) from e
+
+
+DEFAULT_NUMBA_ARGS = {
+    "nopython": True,
+    "cache": CACHE_NUMBA,
+}
+
+
+def numba_jit(*args, **kwargs) -> Callable:  # pragma: no cover
+    kwargs_ = DEFAULT_NUMBA_ARGS.copy()
+    kwargs_.update(kwargs)
+    return jit(*args, **kwargs_)
+
+
+def numba_guvectorize(*args, **kwargs) -> Callable:  # pragma: no cover
+    kwargs_ = DEFAULT_NUMBA_ARGS.copy()
+    kwargs_.update(kwargs)
+    return guvectorize(*args, **kwargs_)
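
The wrappers merge call-site keyword arguments over the shared defaults, so nopython=True and the cache setting apply everywhere unless a caller overrides them. A minimal usage sketch (the decorated functions here are hypothetical, not part of this commit):

    from sgkit.accelerate import numba_jit

    @numba_jit(nogil=True)   # expands to jit(nopython=True, cache=CACHE_NUMBA, nogil=True)
    def add_one(x):
        return x + 1

    @numba_jit(cache=False)  # call-site kwargs override DEFAULT_NUMBA_ARGS
    def add_two(x):
        return x + 2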

sgkit/distance/metrics.py (+5 -10)

@@ -9,8 +9,9 @@
 from typing import Any
 
 import numpy as np
-from numba import cuda, guvectorize, types
+from numba import cuda, types
 
+from sgkit.accelerate import numba_guvectorize
 from sgkit.typing import ArrayLike
 
 # The number of parameters for the map step of the respective distance metric
@@ -20,15 +21,13 @@
 }
 
 
-@guvectorize(  # type: ignore
+@numba_guvectorize(  # type: ignore
     [
         "void(float32[:], float32[:], float32[:], float32[:])",
         "void(float64[:], float64[:], float64[:], float64[:])",
         "void(int8[:], int8[:], int8[:], float64[:])",
     ],
     "(n),(n),(p)->(p)",
-    nopython=True,
-    cache=True,
 )
 def euclidean_map_cpu(
     x: ArrayLike, y: ArrayLike, _: ArrayLike, out: ArrayLike
@@ -78,15 +77,13 @@ def euclidean_reduce_cpu(v: ArrayLike) -> ArrayLike:  # pragma: no cover
     return out
 
 
-@guvectorize(  # type: ignore
+@numba_guvectorize(  # type: ignore
     [
         "void(float32[:], float32[:], float32[:], float32[:])",
         "void(float64[:], float64[:], float64[:], float64[:])",
         "void(int8[:], int8[:], int8[:], float64[:])",
     ],
     "(n),(n),(p)->(p)",
-    nopython=True,
-    cache=True,
 )
 def correlation_map_cpu(
     x: ArrayLike, y: ArrayLike, _: ArrayLike, out: ArrayLike
@@ -141,14 +138,12 @@ def correlation_map_cpu(
     )
 
 
-@guvectorize(  # type: ignore
+@numba_guvectorize(  # type: ignore
     [
         "void(float32[:, :], float32[:])",
         "void(float64[:, :], float64[:])",
     ],
     "(p, m)->()",
-    nopython=True,
-    cache=True,
 )
 def correlation_reduce_cpu(v: ArrayLike, out: ArrayLike) -> None:  # pragma: no cover
     """Corresponding "reduce" function for pearson correlation

sgkit/stats/aggregation.py (+2 -4)

@@ -3,11 +3,11 @@
 import dask.array as da
 import numpy as np
 import xarray as xr
-from numba import guvectorize
 from typing_extensions import Literal
 from xarray import Dataset
 
 from sgkit import variables
+from sgkit.accelerate import numba_guvectorize
 from sgkit.stats.utils import cohort_sum
 from sgkit.typing import ArrayLike
 from sgkit.utils import (
@@ -19,16 +19,14 @@
 Dimension = Literal["samples", "variants"]
 
 
-@guvectorize(  # type: ignore
+@numba_guvectorize(  # type: ignore
     [
         "void(int8[:], uint8[:], uint8[:])",
         "void(int16[:], uint8[:], uint8[:])",
         "void(int32[:], uint8[:], uint8[:])",
         "void(int64[:], uint8[:], uint8[:])",
     ],
     "(k),(n)->(n)",
-    nopython=True,
-    cache=True,
 )
 def count_alleles(
     g: ArrayLike, _: ArrayLike, out: ArrayLike
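
The unused `_` argument in count_alleles is a shape-only device: guvectorize cannot allocate an output whose core dimension appears in no input, so callers pass a dummy array whose length fixes n. A hedged sketch of the calling convention (values illustrative; the expected result assumes the function counts occurrences of each allele index):

    import numpy as np
    from sgkit.stats.aggregation import count_alleles

    g = np.array([0, 1], dtype=np.int8)  # one diploid genotype call (k = 2)
    dummy = np.empty(2, dtype=np.uint8)  # only its length (n = 2 alleles) matters
    # Expected: count_alleles(g, dummy) -> array([1, 1], dtype=uint8)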

sgkit/stats/conversion.py (+2 -4)

@@ -1,21 +1,19 @@
 import dask.array as da
 import numpy as np
-from numba import guvectorize
 from xarray import Dataset
 
 from sgkit import variables
+from sgkit.accelerate import numba_guvectorize
 from sgkit.typing import ArrayLike
 from sgkit.utils import conditional_merge_datasets, create_dataset
 
 
-@guvectorize(  # type: ignore
+@numba_guvectorize(  # type: ignore
     [
         "void(float64[:], uint8[:], float64, int8[:])",
         "void(float32[:], uint8[:], float64, int8[:])",
     ],
     "(p),(k),()->(k)",
-    nopython=True,
-    cache=True,
 )
 def _convert_probability_to_call(
     gp: ArrayLike, _: ArrayLike, threshold: float, out: ArrayLike

sgkit/stats/hwe.py (+5 -3)

@@ -2,10 +2,10 @@
 
 import dask.array as da
 import numpy as np
-from numba import njit
 from xarray import Dataset
 
 from sgkit import variables
+from sgkit.accelerate import numba_jit
 from sgkit.stats.aggregation import count_genotypes
 from sgkit.typing import NDArray
 from sgkit.utils import conditional_merge_datasets, create_dataset
@@ -100,7 +100,9 @@ def hardy_weinberg_p_value(obs_hets: int, obs_hom1: int, obs_hom2: int) -> float
 
 
 # Benchmarks show ~25% improvement w/ fastmath on large (~10M) counts
-hardy_weinberg_p_value_jit = njit(hardy_weinberg_p_value, fastmath=True, nogil=True)
+hardy_weinberg_p_value_jit = numba_jit(
+    hardy_weinberg_p_value, fastmath=True, nogil=True
+)
 
 
 def hardy_weinberg_p_value_vec(
@@ -118,7 +120,7 @@ def hardy_weinberg_p_value_vec(
     return p
 
 
-hardy_weinberg_p_value_vec_jit = njit(
+hardy_weinberg_p_value_vec_jit = numba_jit(
     hardy_weinberg_p_value_vec, fastmath=True, nogil=True
 )
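
Because DEFAULT_NUMBA_ARGS sets nopython=True, numba_jit is effectively njit plus the shared cache default, so these swaps should be behaviour-preserving. As a sketch (not sgkit code):

    # njit(f, fastmath=True, nogil=True)
    # numba_jit(f, fastmath=True, nogil=True)
    # == jit(f, nopython=True, cache=CACHE_NUMBA, fastmath=True, nogil=True)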

sgkit/stats/ld.py (+5 -5)

@@ -7,15 +7,15 @@
 import numpy as np
 import pandas as pd
 from dask.dataframe import DataFrame
-from numba import njit
 from xarray import Dataset
 
 from sgkit import variables
+from sgkit.accelerate import numba_jit
 from sgkit.typing import ArrayLike, DType
 from sgkit.window import _get_chunked_windows, _sizes_to_start_offsets, has_windows
 
 
-@njit(nogil=True, fastmath=False, cache=True)  # type: ignore
+@numba_jit(nogil=True, fastmath=False)  # type: ignore
 def rogers_huff_r_between(gn0: ArrayLike, gn1: ArrayLike) -> float:  # pragma: no cover
     """Rogers Huff *r*.
 
@@ -67,7 +67,7 @@ def rogers_huff_r_between(gn0: ArrayLike, gn1: ArrayLike) -> float:  # pragma: no cover
     return r
 
 
-@njit(nogil=True, fastmath=True, cache=True)  # type: ignore
+@numba_jit(nogil=True, fastmath=True)  # type: ignore
 def rogers_huff_r2_between(gn0: ArrayLike, gn1: ArrayLike) -> float:  # pragma: no cover
     return rogers_huff_r_between(gn0, gn1) ** 2  # type: ignore
 
@@ -202,7 +202,7 @@ def to_ld_df(x: ArrayLike, chunk_index: int) -> DataFrame:
     )
 
 
-@njit(nogil=True, cache=True)  # type: ignore
+@numba_jit(nogil=True)  # type: ignore
 def _ld_matrix_jit(
     x: ArrayLike,
     chunk_window_starts: ArrayLike,
@@ -302,7 +302,7 @@ def _ld_matrix(
     return df
 
 
-@njit(nogil=True, cache=True)  # type: ignore
+@numba_jit(nogil=True)  # type: ignore
 def _maximal_independent_set_jit(
     idi: ArrayLike, idj: ArrayLike, cmp: ArrayLike
 ) -> List[int]:  # pragma: no cover
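
A side effect of funneling every decorator through sgkit.accelerate is that Numba's on-disk cache can be switched off for all of these functions at once. A hypothetical way to do that from Python; the variable must be set before sgkit is first imported, since sgkit.accelerate reads it at import time:

    import os
    os.environ["SGKIT_DISABLE_NUMBA_CACHE"] = "1"  # read once, at import time
    import sgkit  # jitted functions in this session compile with cache=False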

0 commit comments
