Toggle numba caching by environment variable #870

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged 1 commit on Aug 2, 2022
31 changes: 31 additions & 0 deletions sgkit/accelerate.py
@@ -0,0 +1,31 @@
+import os
+from typing import Callable
+
+from numba import guvectorize, jit
+
+_DISABLE_CACHE = os.environ.get("SGKIT_DISABLE_NUMBA_CACHE", "0")
+
+try:
+    CACHE_NUMBA = {"0": True, "1": False}[_DISABLE_CACHE]
+except KeyError as e:  # pragma: no cover
+    raise KeyError(
+        "Environment variable 'SGKIT_DISABLE_NUMBA_CACHE' must be '0' or '1'"
+    ) from e
+
+
+DEFAULT_NUMBA_ARGS = {
+    "nopython": True,
+    "cache": CACHE_NUMBA,
+}
+
+
+def numba_jit(*args, **kwargs) -> Callable:  # pragma: no cover
+    kwargs_ = DEFAULT_NUMBA_ARGS.copy()
+    kwargs_.update(kwargs)
+    return jit(*args, **kwargs_)
+
+
+def numba_guvectorize(*args, **kwargs) -> Callable:  # pragma: no cover
+    kwargs_ = DEFAULT_NUMBA_ARGS.copy()
+    kwargs_.update(kwargs)
+    return guvectorize(*args, **kwargs_)
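
A minimal usage sketch of the new toggle (illustrative, not part of the diff; only numba_jit, DEFAULT_NUMBA_ARGS, and SGKIT_DISABLE_NUMBA_CACHE come from this PR):

# Sketch: the env var must be set before sgkit.accelerate is first imported,
# because CACHE_NUMBA is resolved once at module import time.
import os

os.environ["SGKIT_DISABLE_NUMBA_CACHE"] = "1"  # unset or "0" keeps cache=True

from sgkit.accelerate import numba_jit


def add(x, y):  # toy function, not from sgkit
    return x + y


# Mirrors the hwe.py usage below; per-call kwargs are merged over
# DEFAULT_NUMBA_ARGS, so passing cache=True here would override the toggle.
add_jit = numba_jit(add, nogil=True)
print(add_jit(1, 2))  # compiled in nopython mode, without on-disk caching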
15 changes: 5 additions & 10 deletions sgkit/distance/metrics.py
@@ -9,8 +9,9 @@
 from typing import Any

 import numpy as np
-from numba import cuda, guvectorize, types
+from numba import cuda, types

+from sgkit.accelerate import numba_guvectorize
 from sgkit.typing import ArrayLike

 # The number of parameters for the map step of the respective distance metric
@@ -20,15 +21,13 @@
 }


-@guvectorize(  # type: ignore
+@numba_guvectorize(  # type: ignore
     [
         "void(float32[:], float32[:], float32[:], float32[:])",
         "void(float64[:], float64[:], float64[:], float64[:])",
         "void(int8[:], int8[:], int8[:], float64[:])",
     ],
     "(n),(n),(p)->(p)",
-    nopython=True,
-    cache=True,
 )
 def euclidean_map_cpu(
     x: ArrayLike, y: ArrayLike, _: ArrayLike, out: ArrayLike
@@ -78,15 +77,13 @@ def euclidean_reduce_cpu(v: ArrayLike) -> ArrayLike:  # pragma: no cover
     return out


-@guvectorize(  # type: ignore
+@numba_guvectorize(  # type: ignore
     [
         "void(float32[:], float32[:], float32[:], float32[:])",
         "void(float64[:], float64[:], float64[:], float64[:])",
         "void(int8[:], int8[:], int8[:], float64[:])",
     ],
     "(n),(n),(p)->(p)",
-    nopython=True,
-    cache=True,
 )
 def correlation_map_cpu(
     x: ArrayLike, y: ArrayLike, _: ArrayLike, out: ArrayLike
@@ -141,14 +138,12 @@ def correlation_map_cpu(
     )


-@guvectorize(  # type: ignore
+@numba_guvectorize(  # type: ignore
     [
         "void(float32[:, :], float32[:])",
         "void(float64[:, :], float64[:])",
     ],
     "(p, m)->()",
-    nopython=True,
-    cache=True,
 )
 def correlation_reduce_cpu(v: ArrayLike, out: ArrayLike) -> None:  # pragma: no cover
     """Corresponding "reduce" function for pearson correlation
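
The decorator swaps above are mechanical: with SGKIT_DISABLE_NUMBA_CACHE unset, each kernel compiles with exactly the same options as before. A sketch with a toy kernel (double_old and double_new are illustrative names, not sgkit functions):

from numba import guvectorize

from sgkit.accelerate import numba_guvectorize

SIGS = ["void(float64[:], float64[:])"]
LAYOUT = "(n)->(n)"


# Before this PR: nopython/cache spelled out at every call site.
@guvectorize(SIGS, LAYOUT, nopython=True, cache=True)
def double_old(x, out):
    for i in range(x.shape[0]):
        out[i] = 2.0 * x[i]


# After: the same options come from DEFAULT_NUMBA_ARGS, with cache now
# controlled by the environment variable.
@numba_guvectorize(SIGS, LAYOUT)
def double_new(x, out):
    for i in range(x.shape[0]):
        out[i] = 2.0 * x[i]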
6 changes: 2 additions & 4 deletions sgkit/stats/aggregation.py
@@ -3,11 +3,11 @@
 import dask.array as da
 import numpy as np
 import xarray as xr
-from numba import guvectorize
 from typing_extensions import Literal
 from xarray import Dataset

 from sgkit import variables
+from sgkit.accelerate import numba_guvectorize
 from sgkit.stats.utils import cohort_sum
 from sgkit.typing import ArrayLike
 from sgkit.utils import (
@@ -19,16 +19,14 @@
 Dimension = Literal["samples", "variants"]


-@guvectorize(  # type: ignore
+@numba_guvectorize(  # type: ignore
     [
         "void(int8[:], uint8[:], uint8[:])",
         "void(int16[:], uint8[:], uint8[:])",
         "void(int32[:], uint8[:], uint8[:])",
         "void(int64[:], uint8[:], uint8[:])",
     ],
     "(k),(n)->(n)",
-    nopython=True,
-    cache=True,
 )
 def count_alleles(
     g: ArrayLike, _: ArrayLike, out: ArrayLike
6 changes: 2 additions & 4 deletions sgkit/stats/conversion.py
@@ -1,21 +1,19 @@
 import dask.array as da
 import numpy as np
-from numba import guvectorize
 from xarray import Dataset

 from sgkit import variables
+from sgkit.accelerate import numba_guvectorize
 from sgkit.typing import ArrayLike
 from sgkit.utils import conditional_merge_datasets, create_dataset


-@guvectorize(  # type: ignore
+@numba_guvectorize(  # type: ignore
     [
         "void(float64[:], uint8[:], float64, int8[:])",
         "void(float32[:], uint8[:], float64, int8[:])",
     ],
     "(p),(k),()->(k)",
-    nopython=True,
-    cache=True,
 )
 def _convert_probability_to_call(
     gp: ArrayLike, _: ArrayLike, threshold: float, out: ArrayLike
8 changes: 5 additions & 3 deletions sgkit/stats/hwe.py
@@ -2,10 +2,10 @@

 import dask.array as da
 import numpy as np
-from numba import njit
 from xarray import Dataset

 from sgkit import variables
+from sgkit.accelerate import numba_jit
 from sgkit.stats.aggregation import count_genotypes
 from sgkit.typing import NDArray
 from sgkit.utils import conditional_merge_datasets, create_dataset
@@ -100,7 +100,9 @@ def hardy_weinberg_p_value(obs_hets: int, obs_hom1: int, obs_hom2: int) -> float:


 # Benchmarks show ~25% improvement w/ fastmath on large (~10M) counts
-hardy_weinberg_p_value_jit = njit(hardy_weinberg_p_value, fastmath=True, nogil=True)
+hardy_weinberg_p_value_jit = numba_jit(
+    hardy_weinberg_p_value, fastmath=True, nogil=True
+)


 def hardy_weinberg_p_value_vec(
@@ -118,7 +120,7 @@ def hardy_weinberg_p_value_vec(
     return p


-hardy_weinberg_p_value_vec_jit = njit(
+hardy_weinberg_p_value_vec_jit = numba_jit(
     hardy_weinberg_p_value_vec, fastmath=True, nogil=True
 )
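
The ~25% figure in the fastmath comment above predates this PR. A rough sketch of how such a comparison could be run through the new wrapper (the toy workload, names, and array size are illustrative; timings will vary by machine):

import time

import numpy as np

from sgkit.accelerate import numba_jit


def summed_log(x):
    # Toy floating-point reduction standing in for the HWE computation.
    total = 0.0
    for i in range(x.shape[0]):
        total += np.log(x[i] + 1.0)
    return total


fast = numba_jit(summed_log, fastmath=True, nogil=True)
slow = numba_jit(summed_log, fastmath=False, nogil=True)

x = np.random.rand(10_000_000)
fast(x), slow(x)  # trigger compilation before timing

t0 = time.perf_counter()
fast(x)
t1 = time.perf_counter()
slow(x)
t2 = time.perf_counter()
print(f"fastmath=True:  {t1 - t0:.3f}s")
print(f"fastmath=False: {t2 - t1:.3f}s")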
10 changes: 5 additions & 5 deletions sgkit/stats/ld.py
@@ -7,15 +7,15 @@
 import numpy as np
 import pandas as pd
 from dask.dataframe import DataFrame
-from numba import njit
 from xarray import Dataset

 from sgkit import variables
+from sgkit.accelerate import numba_jit
 from sgkit.typing import ArrayLike, DType
 from sgkit.window import _get_chunked_windows, _sizes_to_start_offsets, has_windows


-@njit(nogil=True, fastmath=False, cache=True)  # type: ignore
+@numba_jit(nogil=True, fastmath=False)  # type: ignore
 def rogers_huff_r_between(gn0: ArrayLike, gn1: ArrayLike) -> float:  # pragma: no cover
     """Rogers Huff *r*.
@@ -67,7 +67,7 @@ def rogers_huff_r_between(gn0: ArrayLike, gn1: ArrayLike) -> float:  # pragma: no cover
     return r


-@njit(nogil=True, fastmath=True, cache=True)  # type: ignore
+@numba_jit(nogil=True, fastmath=True)  # type: ignore
 def rogers_huff_r2_between(gn0: ArrayLike, gn1: ArrayLike) -> float:  # pragma: no cover
     return rogers_huff_r_between(gn0, gn1) ** 2  # type: ignore
@@ -202,7 +202,7 @@ def to_ld_df(x: ArrayLike, chunk_index: int) -> DataFrame:
     )


-@njit(nogil=True, cache=True)  # type: ignore
+@numba_jit(nogil=True)  # type: ignore
 def _ld_matrix_jit(
     x: ArrayLike,
     chunk_window_starts: ArrayLike,
@@ -302,7 +302,7 @@ def _ld_matrix(
     return df


-@njit(nogil=True, cache=True)  # type: ignore
+@numba_jit(nogil=True)  # type: ignore
 def _maximal_independent_set_jit(
     idi: ArrayLike, idj: ArrayLike, cmp: ArrayLike
 ) -> List[int]:  # pragma: no cover