Skip to content

ENH: Add numba engine to df.apply #55104

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 26 commits into from
Oct 22, 2023
Merged
Show file tree
Hide file tree
Changes from 14 commits
Commits
Show all changes
26 commits
Select commit Hold shift + click to select a range
1fa802c
ENH: Add numba engine to df.apply
lithomas1 Sep 12, 2023
c6af7c9
Merge branch 'main' of github.com:pandas-dev/pandas into numba-apply
lithomas1 Sep 12, 2023
0ac544d
complete?
lithomas1 Sep 14, 2023
31b9e20
wip: pass tests
lithomas1 Sep 19, 2023
6190772
Merge branch 'main' of github.com:pandas-dev/pandas into numba-apply
lithomas1 Sep 24, 2023
55df7ad
fix existing tests
lithomas1 Sep 24, 2023
3c89b0f
go for green
lithomas1 Sep 25, 2023
1418d3e
fix checks?
lithomas1 Sep 25, 2023
c143c67
fix pyright
lithomas1 Sep 25, 2023
0d827c4
update docs
lithomas1 Sep 28, 2023
7129ee8
Merge branch 'main' of github.com:pandas-dev/pandas into numba-apply
lithomas1 Sep 28, 2023
b0ba283
Merge branch 'main' into numba-apply
lithomas1 Sep 29, 2023
f4e80a6
eliminate a blank line
lithomas1 Sep 29, 2023
21e2186
update from code review + more tests
lithomas1 Oct 7, 2023
b60bef8
Merge branch 'main' of github.com:pandas-dev/pandas into numba-apply
lithomas1 Oct 9, 2023
ba1d0e0
fix failing tests
lithomas1 Oct 10, 2023
088d27f
Simplify w/ context manager
lithomas1 Oct 12, 2023
60539a1
skip if no numba
lithomas1 Oct 12, 2023
76538d6
simplify more
lithomas1 Oct 12, 2023
cca34f9
specify dtypes
lithomas1 Oct 12, 2023
8b423bf
Merge branch 'main' of github.com:pandas-dev/pandas into numba-apply
lithomas1 Oct 15, 2023
f135def
Merge branch 'main' into numba-apply
lithomas1 Oct 16, 2023
b2e50d2
Merge branch 'numba-apply' of github.com:lithomas1/pandas into numba-…
lithomas1 Oct 16, 2023
f86024f
address code review
lithomas1 Oct 16, 2023
a15293d
add errors for invalid columns
lithomas1 Oct 19, 2023
8fe5d89
adjust message
lithomas1 Oct 19, 2023
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
561 changes: 561 additions & 0 deletions pandas/core/_numba/extensions.py

Large diffs are not rendered by default.

231 changes: 222 additions & 9 deletions pandas/core/apply.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,13 @@
# pyright: reportUnusedImport=false
# Disabled since there's no way to do an ignore for both pyright
# and ruff, and ruff should be sufficient
# (The reason we need this is because the import of the numba extensions is unused
# but is necessary to register the extensions)
from __future__ import annotations

import abc
from collections import defaultdict
import functools
from functools import partial
import inspect
from typing import (
Expand Down Expand Up @@ -29,6 +35,7 @@
NDFrameT,
npt,
)
from pandas.compat._optional import import_optional_dependency
from pandas.errors import SpecificationError
from pandas.util._decorators import cache_readonly
from pandas.util._exceptions import find_stack_level
Expand Down Expand Up @@ -121,6 +128,8 @@ def __init__(
result_type: str | None,
*,
by_row: Literal[False, "compat", "_compat"] = "compat",
engine: str = "python",
engine_kwargs: dict[str, bool] | None = None,
args,
kwargs,
) -> None:
Expand All @@ -133,6 +142,9 @@ def __init__(
self.args = args or ()
self.kwargs = kwargs or {}

self.engine = engine
self.engine_kwargs = {} if engine_kwargs is None else engine_kwargs

if result_type not in [None, "reduce", "broadcast", "expand"]:
raise ValueError(
"invalid value for result_type, must be one "
Expand Down Expand Up @@ -601,6 +613,13 @@ def apply_list_or_dict_like(self) -> DataFrame | Series:
result: Series, DataFrame, or None
Result when self.func is a list-like or dict-like, None otherwise.
"""

if self.engine == "numba":
raise NotImplementedError(
"The 'numba' engine doesn't support list-like/"
"dict likes of callables yet."
)

if self.axis == 1 and isinstance(self.obj, ABCDataFrame):
return self.obj.T.apply(self.func, 0, args=self.args, **self.kwargs).T

Expand Down Expand Up @@ -768,10 +787,16 @@ def __init__(
) -> None:
if by_row is not False and by_row != "compat":
raise ValueError(f"by_row={by_row} not allowed")
self.engine = engine
self.engine_kwargs = engine_kwargs
super().__init__(
obj, func, raw, result_type, by_row=by_row, args=args, kwargs=kwargs
obj,
func,
raw,
result_type,
by_row=by_row,
engine=engine,
engine_kwargs=engine_kwargs,
args=args,
kwargs=kwargs,
)

# ---------------------------------------------------------------
Expand All @@ -792,6 +817,18 @@ def result_columns(self) -> Index:
def series_generator(self) -> Generator[Series, None, None]:
pass

# Abstract hook: build the numba-jitted kernel that applies ``func`` along
# one axis of the frame's values.
# ``functools.cache`` memoizes on the call arguments
# (func, nogil, nopython, parallel), so repeated requests with the same
# arguments get the same kernel object back.
@staticmethod
@functools.cache
@abc.abstractmethod
def generate_numba_apply_func(
func, nogil=True, nopython=True, parallel=False
) -> Callable[[npt.NDArray, Index, Index], dict[int, Any]]:
# No shared implementation: each concrete subclass (e.g. FrameRowApply)
# supplies its own axis-specific kernel builder.
pass

# Abstract hook: run the kernel produced by ``generate_numba_apply_func``
# against this object's data. ``apply_series_numba`` calls this and pairs
# the returned results with ``self.result_index``.
@abc.abstractmethod
def apply_with_numba(self):
# Implemented by the concrete subclasses.
pass

@abc.abstractmethod
def wrap_results_for_axis(
self, results: ResType, res_index: Index
Expand All @@ -815,13 +852,12 @@ def values(self):
def apply(self) -> DataFrame | Series:
"""compute the results"""

if self.engine == "numba" and not self.raw:
raise ValueError(
"The numba engine in DataFrame.apply can only be used when raw=True"
)

# dispatch to handle list-like or dict-like
if is_list_like(self.func):
if self.engine == "numba":
raise NotImplementedError(
"the 'numba' engine doesn't support lists of callables yet"
)
return self.apply_list_or_dict_like()

# all empty
Expand All @@ -830,17 +866,31 @@ def apply(self) -> DataFrame | Series:

# string dispatch
if isinstance(self.func, str):
if self.engine == "numba":
raise NotImplementedError(
"the 'numba' engine doesn't support using "
"a string as the callable function"
)
return self.apply_str()

# ufunc
elif isinstance(self.func, np.ufunc):
if self.engine == "numba":
raise NotImplementedError(
"the 'numba' engine doesn't support "
"using a numpy ufunc as the callable function"
)
with np.errstate(all="ignore"):
results = self.obj._mgr.apply("apply", func=self.func)
# _constructor will retain self.index and self.columns
return self.obj._constructor_from_mgr(results, axes=results.axes)

# broadcasting
if self.result_type == "broadcast":
if self.engine == "numba":
raise NotImplementedError(
"the 'numba' engine doesn't support result_type='broadcast'"
)
return self.apply_broadcast(self.obj)

# one axis empty
Expand Down Expand Up @@ -997,7 +1047,10 @@ def apply_broadcast(self, target: DataFrame) -> DataFrame:
return result

def apply_standard(self):
    """
    Apply ``self.func`` along the chosen axis and wrap the raw results.

    The per-Series results are computed by exactly one engine: the pure
    Python series generator when ``self.engine == "python"``, otherwise the
    numba path. The previous unconditional call to
    ``apply_series_generator`` has been removed: it ran the Python path a
    second time for ``engine="python"`` and ran it needlessly (discarding
    the result) for the numba engine.

    Returns
    -------
    Whatever ``self.wrap_results`` builds from the results and their index.
    """
    if self.engine == "python":
        results, res_index = self.apply_series_generator()
    else:
        results, res_index = self.apply_series_numba()

    # wrap results
    return self.wrap_results(results, res_index)
Expand All @@ -1021,6 +1074,18 @@ def apply_series_generator(self) -> tuple[ResType, Index]:

return results, res_index

def apply_series_numba(self):
    """
    Run the Series-based (``raw=False``) apply through the numba engine.

    Returns
    -------
    tuple
        ``(results, result_index)`` where ``results`` is whatever
        ``self.apply_with_numba()`` returns.

    Raises
    ------
    NotImplementedError
        If ``parallel`` is truthy in ``self.engine_kwargs``, or if either
        the index or the columns contain duplicate labels.
    """
    wants_parallel = self.engine_kwargs.get("parallel", False)
    if wants_parallel:
        raise NotImplementedError(
            "Parallel apply is not supported when raw=False and engine='numba'"
        )

    labels_unique = self.obj.index.is_unique and self.columns.is_unique
    if not labels_unique:
        raise NotImplementedError(
            "The index/columns must be unique when raw=False and engine='numba'"
        )

    return self.apply_with_numba(), self.result_index

def wrap_results(self, results: ResType, res_index: Index) -> DataFrame | Series:
from pandas import Series

Expand Down Expand Up @@ -1060,6 +1125,76 @@ class FrameRowApply(FrameApply):
def series_generator(self) -> Generator[Series, None, None]:
return (self.obj._ixs(i, axis=1) for i in range(len(self.columns)))

@staticmethod
@functools.cache
def generate_numba_apply_func(
func, nogil=True, nopython=True, parallel=False
) -> Callable[[npt.NDArray, Index, Index], dict[int, Any]]:
from pandas import Series

# Dummy import just to make the extensions loaded in
# This isn't an entrypoint since we don't want users
# using Series/DF in numba code outside of apply
from pandas.core._numba.extensions import SeriesType # noqa: F401
from pandas.core._numba.extensions import maybe_cast_str

numba = import_optional_dependency("numba")

jitted_udf = numba.extending.register_jitable(func)

@numba.jit(nogil=nogil, nopython=nopython, parallel=parallel)
def numba_func(values, col_names, df_index):
results = {}
for j in range(values.shape[1]):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
for j in range(values.shape[1]):
for j in numba.prange(values.shape[1]):

? (and below)

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for pointing this out!

I think it probably makes more sense to disable parallel mode for now, since the dict in numba isn't thread-safe.

The overhead from the boxing/unboxing is also really high (99% of the time spent is there), so I doubt parallel will give a good speedup, at least for now.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

OK makes sense. Would be good to put a TODO: comment explaining why we shouldn't use prange for now

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

added a comment.

# Create the series
ser = Series(
values[:, j], index=df_index, name=maybe_cast_str(col_names[j])
)
results[j] = jitted_udf(ser)
return results

return numba_func

def apply_with_numba(self) -> dict[int, Any]:
    """
    Apply ``self.func`` to each column through the jitted numba kernel.

    numpy/numba don't handle object arrays of strings well, so when the
    columns or the index are object-dtype, their ``_data`` is temporarily
    replaced by a unicode ("U") copy for the duration of the kernel call.

    Both axes are validated before anything is mutated, and the original
    ``_data`` is put back in a ``finally`` block, so a failure inside the
    kernel (or during validation) can no longer leave the caller's
    Index objects holding the temporary unicode arrays.

    Returns
    -------
    dict
        Plain ``dict`` mapping column position to the result of ``func``.

    Raises
    ------
    ValueError
        If an object-dtype columns/index holds anything other than strings.
    """
    nb_func = self.generate_numba_apply_func(
        cast(Callable, self.func), **self.engine_kwargs
    )
    columns = self.columns
    df_index = self.obj.index

    # Validate first: nothing is mutated until both axes are known-good.
    orig_cols = columns.to_numpy()
    fix_obj_colnames = columns._data.dtype == object
    if fix_obj_colnames and not lib.is_string_array(orig_cols):
        raise ValueError(
            "The numba engine only supports "
            "using string or numeric column names"
        )

    orig_index = self.index.to_numpy()
    fix_obj_index = df_index._data.dtype == object
    if fix_obj_index and not lib.is_string_array(orig_index):
        raise ValueError(
            "The numba engine only supports "
            "using string or numeric index values"
        )

    try:
        if fix_obj_colnames:
            columns._data = orig_cols.astype("U")
        if fix_obj_index:
            df_index._data = orig_index.astype("U")
        # Convert the numba result to a regular dict before returning it.
        return dict(nb_func(self.values, columns, df_index))
    finally:
        # Always undo the swap, including when the kernel raises.
        if fix_obj_colnames:
            columns._data = orig_cols
        if fix_obj_index:
            df_index._data = orig_index

# Index used to label the results of this apply: the frame's columns.
@property
def result_index(self) -> Index:
return self.columns
Expand Down Expand Up @@ -1143,6 +1278,84 @@ def series_generator(self) -> Generator[Series, None, None]:
object.__setattr__(ser, "_name", name)
yield ser

# Build the numba-jitted kernel that applies ``func`` once per entry along
# the first axis of ``values`` (``values.shape[0]``), wrapping each slice in
# a Series first.
# ``functools.cache`` memoizes on (func, nogil, nopython, parallel), so the
# same arguments return the same kernel object instead of rebuilding it.
# Returns ``numba_func(values, col_names_index, index)``, which maps each
# position ``i`` to ``func``'s result for ``values[i]``.
@staticmethod
@functools.cache
def generate_numba_apply_func(
func, nogil=True, nopython=True, parallel=False
) -> Callable[[npt.NDArray, Index, Index], dict[int, Any]]:
from pandas import Series
# Dummy import just to get the numba extensions registered.
# This isn't an entrypoint since we don't want users
# using Series/DF in numba code outside of apply
from pandas.core._numba.extensions import SeriesType # noqa: F401
from pandas.core._numba.extensions import maybe_cast_str

# numba is an optional dependency, so it is imported lazily here.
numba = import_optional_dependency("numba")

# Make the user function callable from inside the jitted kernel below.
jitted_udf = numba.extending.register_jitable(func)

@numba.jit(nogil=nogil, nopython=nopython, parallel=parallel)
def numba_func(values, col_names_index, index):
results = {}
# NOTE: plain ``range`` rather than ``numba.prange`` -- per the review
# discussion, the numba dict that collects the results isn't thread-safe.
for i in range(values.shape[0]):
# Create the series: data is values[i], labelled by col_names_index,
# and named after the i-th index label.
# TODO: values corrupted without the copy
ser = Series(
values[i].copy(),
index=col_names_index,
name=maybe_cast_str(index[i]),
)
results[i] = jitted_udf(ser)

return results

return numba_func

def apply_with_numba(self) -> dict[int, Any]:
    """
    Apply ``self.func`` to each row through the jitted numba kernel.

    numpy/numba don't handle object arrays of strings well, so when the
    columns or the index are object-dtype, their ``_data`` is temporarily
    replaced by a unicode ("U") copy for the duration of the kernel call.

    Both axes are validated before anything is mutated, and the original
    ``_data`` is put back in a ``finally`` block, so a failure inside the
    kernel (or during validation) can no longer leave the caller's
    Index objects holding the temporary unicode arrays.

    Returns
    -------
    dict
        Plain ``dict`` mapping row position to the result of ``func``.

    Raises
    ------
    ValueError
        If an object-dtype columns/index holds anything other than strings.
    """
    nb_func = self.generate_numba_apply_func(
        cast(Callable, self.func), **self.engine_kwargs
    )
    columns = self.columns
    df_index = self.obj.index

    # Validate first: nothing is mutated until both axes are known-good.
    orig_cols = columns.to_numpy()
    fix_obj_colnames = columns._data.dtype == object
    if fix_obj_colnames and not lib.is_string_array(orig_cols):
        raise ValueError(
            "The numba engine only supports "
            "using string or numeric column names"
        )

    orig_index = self.index.to_numpy()
    fix_obj_index = df_index._data.dtype == object
    if fix_obj_index and not lib.is_string_array(orig_index):
        raise ValueError(
            "The numba engine only supports "
            "using string or numeric index values"
        )

    try:
        if fix_obj_colnames:
            columns._data = orig_cols.astype("U")
        if fix_obj_index:
            df_index._data = orig_index.astype("U")
        # Convert from numba dict to regular dict
        # Our isinstance checks in the df constructor don't pass for
        # numba's typed dict
        return dict(nb_func(self.values, columns, df_index))
    finally:
        # Always undo the swap, including when the kernel raises.
        if fix_obj_colnames:
            columns._data = orig_cols
        if fix_obj_index:
            df_index._data = orig_index

# Index used to label the results of this apply: the frame's index.
@property
def result_index(self) -> Index:
return self.index
Expand Down
5 changes: 3 additions & 2 deletions pandas/core/frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -10051,6 +10051,9 @@ def apply(
- nogil (release the GIL inside the JIT compiled function)
- parallel (try to apply the function in parallel over the DataFrame)

Note: Due to limitations within numba/how pandas interfaces with numba,
you should only use this if raw=True

Note: The numba compiler only supports a subset of
valid Python/numpy operations.

Expand All @@ -10060,8 +10063,6 @@ def apply(
<https://numba.pydata.org/numba-doc/dev/reference/numpysupported.html>`_
in numba to learn what you can or cannot use in the passed function.

As of right now, the numba engine can only be used with raw=True.

.. versionadded:: 2.2.0

engine_kwargs : dict
Expand Down
12 changes: 12 additions & 0 deletions pandas/tests/apply/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,3 +16,15 @@ def int_frame_const_col():
columns=["A", "B", "C"],
)
return df


@pytest.fixture(params=["python", "numba"])
def engine(request):
    """
    Parametrize a test over the available apply engines.

    The "numba" variant is skipped (via ``pytest.importorskip``) when numba
    cannot be imported.
    """
    selected = request.param
    if selected == "numba":
        pytest.importorskip("numba")
    return selected


@pytest.fixture(params=[0, 1])
def apply_axis(request):
    """Parametrize a test over both axis values, 0 and 1."""
    axis = request.param
    return axis
Loading