Skip to content

ENH: Add numba engine to df.apply #55104

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 26 commits into from
Oct 22, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
26 commits
Select commit Hold shift + click to select a range
1fa802c
ENH: Add numba engine to df.apply
lithomas1 Sep 12, 2023
c6af7c9
Merge branch 'main' of github.com:pandas-dev/pandas into numba-apply
lithomas1 Sep 12, 2023
0ac544d
complete?
lithomas1 Sep 14, 2023
31b9e20
wip: pass tests
lithomas1 Sep 19, 2023
6190772
Merge branch 'main' of github.com:pandas-dev/pandas into numba-apply
lithomas1 Sep 24, 2023
55df7ad
fix existing tests
lithomas1 Sep 24, 2023
3c89b0f
go for green
lithomas1 Sep 25, 2023
1418d3e
fix checks?
lithomas1 Sep 25, 2023
c143c67
fix pyright
lithomas1 Sep 25, 2023
0d827c4
update docs
lithomas1 Sep 28, 2023
7129ee8
Merge branch 'main' of github.com:pandas-dev/pandas into numba-apply
lithomas1 Sep 28, 2023
b0ba283
Merge branch 'main' into numba-apply
lithomas1 Sep 29, 2023
f4e80a6
eliminate a blank line
lithomas1 Sep 29, 2023
21e2186
update from code review + more tests
lithomas1 Oct 7, 2023
b60bef8
Merge branch 'main' of github.com:pandas-dev/pandas into numba-apply
lithomas1 Oct 9, 2023
ba1d0e0
fix failing tests
lithomas1 Oct 10, 2023
088d27f
Simplify w/ context manager
lithomas1 Oct 12, 2023
60539a1
skip if no numba
lithomas1 Oct 12, 2023
76538d6
simplify more
lithomas1 Oct 12, 2023
cca34f9
specify dtypes
lithomas1 Oct 12, 2023
8b423bf
Merge branch 'main' of github.com:pandas-dev/pandas into numba-apply
lithomas1 Oct 15, 2023
f135def
Merge branch 'main' into numba-apply
lithomas1 Oct 16, 2023
b2e50d2
Merge branch 'numba-apply' of github.com:lithomas1/pandas into numba-…
lithomas1 Oct 16, 2023
f86024f
address code review
lithomas1 Oct 16, 2023
a15293d
add errors for invalid columns
lithomas1 Oct 19, 2023
8fe5d89
adjust message
lithomas1 Oct 19, 2023
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
575 changes: 575 additions & 0 deletions pandas/core/_numba/extensions.py

Large diffs are not rendered by default.

184 changes: 175 additions & 9 deletions pandas/core/apply.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import abc
from collections import defaultdict
import functools
from functools import partial
import inspect
from typing import (
Expand Down Expand Up @@ -29,14 +30,17 @@
NDFrameT,
npt,
)
from pandas.compat._optional import import_optional_dependency
from pandas.errors import SpecificationError
from pandas.util._decorators import cache_readonly
from pandas.util._exceptions import find_stack_level

from pandas.core.dtypes.cast import is_nested_object
from pandas.core.dtypes.common import (
is_dict_like,
is_extension_array_dtype,
is_list_like,
is_numeric_dtype,
is_sequence,
)
from pandas.core.dtypes.dtypes import (
Expand Down Expand Up @@ -121,6 +125,8 @@ def __init__(
result_type: str | None,
*,
by_row: Literal[False, "compat", "_compat"] = "compat",
engine: str = "python",
engine_kwargs: dict[str, bool] | None = None,
args,
kwargs,
) -> None:
Expand All @@ -133,6 +139,9 @@ def __init__(
self.args = args or ()
self.kwargs = kwargs or {}

self.engine = engine
self.engine_kwargs = {} if engine_kwargs is None else engine_kwargs

if result_type not in [None, "reduce", "broadcast", "expand"]:
raise ValueError(
"invalid value for result_type, must be one "
Expand Down Expand Up @@ -601,6 +610,13 @@ def apply_list_or_dict_like(self) -> DataFrame | Series:
result: Series, DataFrame, or None
Result when self.func is a list-like or dict-like, None otherwise.
"""

if self.engine == "numba":
raise NotImplementedError(
"The 'numba' engine doesn't support list-like/"
"dict likes of callables yet."
)

if self.axis == 1 and isinstance(self.obj, ABCDataFrame):
return self.obj.T.apply(self.func, 0, args=self.args, **self.kwargs).T

Expand Down Expand Up @@ -768,10 +784,16 @@ def __init__(
) -> None:
if by_row is not False and by_row != "compat":
raise ValueError(f"by_row={by_row} not allowed")
self.engine = engine
self.engine_kwargs = engine_kwargs
super().__init__(
obj, func, raw, result_type, by_row=by_row, args=args, kwargs=kwargs
obj,
func,
raw,
result_type,
by_row=by_row,
engine=engine,
engine_kwargs=engine_kwargs,
args=args,
kwargs=kwargs,
)

# ---------------------------------------------------------------
Expand All @@ -792,6 +814,32 @@ def result_columns(self) -> Index:
def series_generator(self) -> Generator[Series, None, None]:
pass

@staticmethod
@functools.cache
@abc.abstractmethod
def generate_numba_apply_func(
    func, nogil=True, nopython=True, parallel=False
) -> Callable[[npt.NDArray, Index, Index], dict[int, Any]]:
    """
    Build the compiled function that applies ``func`` with numba.

    Abstract hook with no base behaviour. Because of ``functools.cache``,
    repeated calls with the same arguments reuse the first result.

    Parameters
    ----------
    func : callable
        The user function to apply.
    nogil : bool, default True
    nopython : bool, default True
    parallel : bool, default False
        Compilation options; the overrides in this file forward all three
        to ``numba.jit``.

    Returns
    -------
    callable
        Per the annotation: takes an ndarray and two ``Index`` objects and
        returns a dict with integer keys.
    """

@abc.abstractmethod
def apply_with_numba(self):
    """
    Apply ``self.func`` through the numba engine.

    Abstract hook with no base behaviour. ``apply_series_numba`` returns
    whatever this method produces as its raw results.
    """

def validate_values_for_numba(self):
    """
    Check that every column dtype of ``self.obj`` is usable by the numba engine.

    Columns are inspected in the order of ``self.obj.dtypes`` and the first
    offending column is reported.

    Raises
    ------
    ValueError
        If a column's dtype is not numeric, or if it is numeric but is an
        extension array dtype.
    """
    for label, col_dtype in self.obj.dtypes.items():
        # The numeric check comes first, so a non-numeric extension dtype is
        # reported as "not numeric" rather than as "extension array".
        if not is_numeric_dtype(col_dtype):
            problem = (
                f"Column {label} must have a numeric dtype. "
                f"Found '{col_dtype}' instead"
            )
        elif is_extension_array_dtype(col_dtype):
            problem = (
                f"Column {label} is backed by an extension array, "
                "which is not supported by the numba engine."
            )
        else:
            continue
        raise ValueError(problem)

@abc.abstractmethod
def wrap_results_for_axis(
self, results: ResType, res_index: Index
Expand All @@ -815,13 +863,12 @@ def values(self):
def apply(self) -> DataFrame | Series:
"""compute the results"""

if self.engine == "numba" and not self.raw:
raise ValueError(
"The numba engine in DataFrame.apply can only be used when raw=True"
)

# dispatch to handle list-like or dict-like
if is_list_like(self.func):
if self.engine == "numba":
raise NotImplementedError(
"the 'numba' engine doesn't support lists of callables yet"
)
return self.apply_list_or_dict_like()

# all empty
Expand All @@ -830,17 +877,31 @@ def apply(self) -> DataFrame | Series:

# string dispatch
if isinstance(self.func, str):
if self.engine == "numba":
raise NotImplementedError(
"the 'numba' engine doesn't support using "
"a string as the callable function"
)
return self.apply_str()

# ufunc
elif isinstance(self.func, np.ufunc):
if self.engine == "numba":
raise NotImplementedError(
"the 'numba' engine doesn't support "
"using a numpy ufunc as the callable function"
)
with np.errstate(all="ignore"):
results = self.obj._mgr.apply("apply", func=self.func)
# _constructor will retain self.index and self.columns
return self.obj._constructor_from_mgr(results, axes=results.axes)

# broadcasting
if self.result_type == "broadcast":
if self.engine == "numba":
raise NotImplementedError(
"the 'numba' engine doesn't support result_type='broadcast'"
)
return self.apply_broadcast(self.obj)

# one axis empty
Expand Down Expand Up @@ -997,7 +1058,10 @@ def apply_broadcast(self, target: DataFrame) -> DataFrame:
return result

def apply_standard(self):
    """
    Compute the raw results with the configured engine and wrap them.

    ``self.engine == "python"`` uses ``apply_series_generator``; any other
    engine value is routed to ``apply_series_numba``. Exactly one of the
    two is called per invocation.

    Returns
    -------
    The value of ``self.wrap_results(results, res_index)``.
    """
    # Dispatch once. An unconditional apply_series_generator() call before
    # this branch would run the python path a second time, and would run it
    # even when the numba engine was requested.
    if self.engine == "python":
        results, res_index = self.apply_series_generator()
    else:
        results, res_index = self.apply_series_numba()

    # wrap results
    return self.wrap_results(results, res_index)
Expand All @@ -1021,6 +1085,19 @@ def apply_series_generator(self) -> tuple[ResType, Index]:

return results, res_index

def apply_series_numba(self):
    """
    Run the numba engine after checking its preconditions.

    Returns
    -------
    tuple
        ``(results, self.result_index)`` where ``results`` is whatever
        ``self.apply_with_numba()`` returned.

    Raises
    ------
    NotImplementedError
        If ``engine_kwargs`` requests ``parallel``, or if the index or the
        columns contain duplicates.
    ValueError
        Propagated from ``validate_values_for_numba``.
    """
    parallel_requested = self.engine_kwargs.get("parallel", False)
    if parallel_requested:
        raise NotImplementedError(
            "Parallel apply is not supported when raw=False and engine='numba'"
        )

    labels_unique = self.obj.index.is_unique and self.columns.is_unique
    if not labels_unique:
        raise NotImplementedError(
            "The index/columns must be unique when raw=False and engine='numba'"
        )

    # Validate the dtypes before anything is handed to the compiled function.
    self.validate_values_for_numba()
    return self.apply_with_numba(), self.result_index

def wrap_results(self, results: ResType, res_index: Index) -> DataFrame | Series:
from pandas import Series

Expand Down Expand Up @@ -1060,6 +1137,49 @@ class FrameRowApply(FrameApply):
def series_generator(self) -> Generator[Series, None, None]:
return (self.obj._ixs(i, axis=1) for i in range(len(self.columns)))

@staticmethod
@functools.cache
def generate_numba_apply_func(
func, nogil=True, nopython=True, parallel=False
) -> Callable[[npt.NDArray, Index, Index], dict[int, Any]]:
numba = import_optional_dependency("numba")
from pandas import Series

# Import helper from extensions to cast string object -> np strings
# Note: This also has the side effect of loading our numba extensions
from pandas.core._numba.extensions import maybe_cast_str

jitted_udf = numba.extending.register_jitable(func)

# NOTE: parallel mode is effectively disabled: apply_series_numba rejects
# parallel=True, since the dicts in numba aren't thread-safe. Keep the plain
# range loop (no numba.prange) until that changes.
@numba.jit(nogil=nogil, nopython=nopython, parallel=parallel)
def numba_func(values, col_names, df_index):
results = {}
for j in range(values.shape[1]):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
for j in range(values.shape[1]):
for j in numba.prange(values.shape[1]):

? (and below)

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for pointing this out!

I think it probably makes more sense to disable parallel mode for now, since the dict in numba isn't thread-safe.

The overhead from the boxing/unboxing is also really high (99% of the time spent is there), so I doubt parallel will give a good speedup, at least for now.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

OK makes sense. Would be good to put a TODO: comment explaining why we shouldn't use prange for now

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

added a comment.

# Create the series
ser = Series(
values[:, j], index=df_index, name=maybe_cast_str(col_names[j])
)
results[j] = jitted_udf(ser)
return results

return numba_func

def apply_with_numba(self) -> dict[int, Any]:
    """
    Apply ``self.func`` with the numba-compiled function.

    The compiled function is obtained from ``generate_numba_apply_func``
    with ``self.engine_kwargs`` forwarded as keyword arguments. It is called
    with ``self.values`` plus the objects yielded by the ``set_numba_data``
    context managers for ``self.columns`` and ``self.obj.index``.

    Returns
    -------
    dict
        A plain ``dict`` built from the compiled function's result.
    """
    compiled = self.generate_numba_apply_func(
        cast(Callable, self.func), **self.engine_kwargs
    )
    from pandas.core._numba.extensions import set_numba_data

    with set_numba_data(self.obj.index) as nb_index:
        with set_numba_data(self.columns) as nb_columns:
            # Convert from numba dict to regular dict: the isinstance checks
            # in the df constructor don't pass for numba's typed dict.
            return dict(compiled(self.values, nb_columns, nb_index))

@property
def result_index(self) -> Index:
return self.columns
Expand Down Expand Up @@ -1143,6 +1263,52 @@ def series_generator(self) -> Generator[Series, None, None]:
object.__setattr__(ser, "_name", name)
yield ser

@staticmethod
@functools.cache
def generate_numba_apply_func(
    func, nogil=True, nopython=True, parallel=False
) -> Callable[[npt.NDArray, Index, Index], dict[int, Any]]:
    """
    Build a numba-jitted function that applies ``func`` to each row.

    The result is memoized by ``functools.cache`` on the arguments, so the
    same ``func`` and options reuse the first compiled function object.

    Parameters
    ----------
    func : callable
        User function; registered with ``numba.extending.register_jitable``.
    nogil : bool, default True
    nopython : bool, default True
    parallel : bool, default False
        Forwarded unchanged to ``numba.jit``.

    Returns
    -------
    callable
        ``numba_func(values, col_names_index, index)``. For each ``i`` along
        the first axis of ``values`` it builds a Series from a copy of
        ``values[i]`` and stores ``func``'s result under key ``i``.
    """
    numba = import_optional_dependency("numba")
    from pandas import Series
    from pandas.core._numba.extensions import maybe_cast_str

    # Make the user function callable from inside the jitted loop below.
    jitted_udf = numba.extending.register_jitable(func)

    @numba.jit(nogil=nogil, nopython=nopython, parallel=parallel)
    def numba_func(values, col_names_index, index):
        results = {}
        # A plain range loop is used rather than a parallel one: the dicts
        # in numba aren't thread-safe, and results are collected in a dict.
        for i in range(values.shape[0]):
            # Create the series
            # TODO: values corrupted without the copy
            ser = Series(
                values[i].copy(),
                index=col_names_index,
                name=maybe_cast_str(index[i]),
            )
            results[i] = jitted_udf(ser)

        return results

    return numba_func

def apply_with_numba(self) -> dict[int, Any]:
    """
    Apply ``self.func`` with the numba-compiled function.

    The compiled function is obtained from ``generate_numba_apply_func``
    with ``self.engine_kwargs`` forwarded as keyword arguments. It is called
    with ``self.values`` plus the objects yielded by the ``set_numba_data``
    context managers for ``self.columns`` and ``self.obj.index``.

    Returns
    -------
    dict
        A plain ``dict`` built from the compiled function's result.
    """
    compiled = self.generate_numba_apply_func(
        cast(Callable, self.func), **self.engine_kwargs
    )

    from pandas.core._numba.extensions import set_numba_data

    with set_numba_data(self.obj.index) as nb_index:
        with set_numba_data(self.columns) as nb_columns:
            # Convert from numba dict to regular dict: the isinstance checks
            # in the df constructor don't pass for numba's typed dict.
            return dict(compiled(self.values, nb_columns, nb_index))

@property
def result_index(self) -> Index:
return self.index
Expand Down
5 changes: 3 additions & 2 deletions pandas/core/frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -10092,6 +10092,9 @@ def apply(
- nogil (release the GIL inside the JIT compiled function)
- parallel (try to apply the function in parallel over the DataFrame)

Note: Due to limitations within numba/how pandas interfaces with numba,
you should only use this if raw=True

Note: The numba compiler only supports a subset of
valid Python/numpy operations.

Expand All @@ -10101,8 +10104,6 @@ def apply(
<https://numba.pydata.org/numba-doc/dev/reference/numpysupported.html>`_
in numba to learn what you can or cannot use in the passed function.

As of right now, the numba engine can only be used with raw=True.

.. versionadded:: 2.2.0

engine_kwargs : dict
Expand Down
12 changes: 12 additions & 0 deletions pandas/tests/apply/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,3 +16,15 @@ def int_frame_const_col():
columns=["A", "B", "C"],
)
return df


@pytest.fixture(params=["python", "numba"])
def engine(request):
    """
    Fixture yielding each apply engine name: "python", then "numba".

    The "numba" parametrization is skipped when numba cannot be imported.
    """
    selected = request.param
    if selected == "numba":
        pytest.importorskip("numba")
    return selected


@pytest.fixture(params=[0, 1])
def apply_axis(request):
    """
    Fixture yielding each axis value, 0 and 1, for apply tests.
    """
    axis = request.param
    return axis
Loading