-
-
Notifications
You must be signed in to change notification settings - Fork 18.5k
ENH: Add numba engine to df.apply #55104
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 13 commits
1fa802c
c6af7c9
0ac544d
31b9e20
6190772
55df7ad
3c89b0f
1418d3e
c143c67
0d827c4
7129ee8
b0ba283
f4e80a6
21e2186
b60bef8
ba1d0e0
088d27f
60539a1
76538d6
cca34f9
8b423bf
f135def
b2e50d2
f86024f
a15293d
8fe5d89
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Large diffs are not rendered by default.
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,7 +1,13 @@ | ||
# pyright: reportUnusedImport=false | ||
# Disabled since there's no way to do an ignore for both pyright | ||
# and ruff, and ruff should be sufficient | ||
# (The reason we need this is because the import of the numba extensions is unused | ||
# but is necessary to register the extensions) | ||
from __future__ import annotations | ||
|
||
import abc | ||
from collections import defaultdict | ||
import functools | ||
from functools import partial | ||
import inspect | ||
from typing import ( | ||
|
@@ -29,6 +35,7 @@ | |
NDFrameT, | ||
npt, | ||
) | ||
from pandas.compat._optional import import_optional_dependency | ||
from pandas.errors import SpecificationError | ||
from pandas.util._decorators import cache_readonly | ||
from pandas.util._exceptions import find_stack_level | ||
|
@@ -121,6 +128,8 @@ def __init__( | |
result_type: str | None, | ||
*, | ||
by_row: Literal[False, "compat", "_compat"] = "compat", | ||
engine: str = "python", | ||
engine_kwargs: dict[str, bool] | None = None, | ||
args, | ||
kwargs, | ||
) -> None: | ||
|
@@ -133,6 +142,9 @@ def __init__( | |
self.args = args or () | ||
self.kwargs = kwargs or {} | ||
|
||
self.engine = engine | ||
self.engine_kwargs = {} if engine_kwargs is None else engine_kwargs | ||
|
||
if result_type not in [None, "reduce", "broadcast", "expand"]: | ||
raise ValueError( | ||
"invalid value for result_type, must be one " | ||
|
@@ -601,6 +613,13 @@ def apply_list_or_dict_like(self) -> DataFrame | Series: | |
result: Series, DataFrame, or None | ||
Result when self.func is a list-like or dict-like, None otherwise. | ||
""" | ||
|
||
if self.engine == "numba": | ||
raise NotImplementedError( | ||
"The 'numba' engine doesn't support list-like/" | ||
"dict likes of callables yet." | ||
) | ||
|
||
if self.axis == 1 and isinstance(self.obj, ABCDataFrame): | ||
return self.obj.T.apply(self.func, 0, args=self.args, **self.kwargs).T | ||
|
||
|
@@ -768,10 +787,16 @@ def __init__( | |
) -> None: | ||
if by_row is not False and by_row != "compat": | ||
raise ValueError(f"by_row={by_row} not allowed") | ||
self.engine = engine | ||
self.engine_kwargs = engine_kwargs | ||
super().__init__( | ||
obj, func, raw, result_type, by_row=by_row, args=args, kwargs=kwargs | ||
obj, | ||
func, | ||
raw, | ||
result_type, | ||
by_row=by_row, | ||
engine=engine, | ||
engine_kwargs=engine_kwargs, | ||
args=args, | ||
kwargs=kwargs, | ||
) | ||
|
||
# --------------------------------------------------------------- | ||
|
@@ -792,6 +817,18 @@ def result_columns(self) -> Index: | |
def series_generator(self) -> Generator[Series, None, None]: | ||
pass | ||
|
||
@staticmethod | ||
@functools.cache | ||
@abc.abstractmethod | ||
def generate_numba_apply_func( | ||
func, nogil=True, nopython=True, parallel=False | ||
) -> Callable[[npt.NDArray, Index, Index], dict[int, Any]]: | ||
pass | ||
|
||
@abc.abstractmethod | ||
def apply_with_numba(self): | ||
pass | ||
|
||
@abc.abstractmethod | ||
def wrap_results_for_axis( | ||
self, results: ResType, res_index: Index | ||
|
@@ -815,13 +852,12 @@ def values(self): | |
def apply(self) -> DataFrame | Series: | ||
"""compute the results""" | ||
|
||
if self.engine == "numba" and not self.raw: | ||
raise ValueError( | ||
"The numba engine in DataFrame.apply can only be used when raw=True" | ||
) | ||
|
||
# dispatch to handle list-like or dict-like | ||
if is_list_like(self.func): | ||
if self.engine == "numba": | ||
raise NotImplementedError( | ||
"the 'numba' engine doesn't support lists of callables yet" | ||
) | ||
return self.apply_list_or_dict_like() | ||
|
||
# all empty | ||
|
@@ -830,17 +866,31 @@ def apply(self) -> DataFrame | Series: | |
|
||
# string dispatch | ||
if isinstance(self.func, str): | ||
if self.engine == "numba": | ||
raise NotImplementedError( | ||
"the 'numba' engine doesn't support using " | ||
"a string as the callable function" | ||
) | ||
return self.apply_str() | ||
|
||
# ufunc | ||
elif isinstance(self.func, np.ufunc): | ||
if self.engine == "numba": | ||
raise NotImplementedError( | ||
"the 'numba' engine doesn't support " | ||
"using a numpy ufunc as the callable function" | ||
) | ||
with np.errstate(all="ignore"): | ||
results = self.obj._mgr.apply("apply", func=self.func) | ||
# _constructor will retain self.index and self.columns | ||
return self.obj._constructor_from_mgr(results, axes=results.axes) | ||
|
||
# broadcasting | ||
if self.result_type == "broadcast": | ||
if self.engine == "numba": | ||
raise NotImplementedError( | ||
"the 'numba' engine doesn't support result_type='broadcast'" | ||
) | ||
return self.apply_broadcast(self.obj) | ||
|
||
# one axis empty | ||
|
@@ -997,7 +1047,10 @@ def apply_broadcast(self, target: DataFrame) -> DataFrame: | |
return result | ||
|
||
def apply_standard(self): | ||
results, res_index = self.apply_series_generator() | ||
if self.engine == "python": | ||
results, res_index = self.apply_series_generator() | ||
else: | ||
results, res_index = self.apply_series_numba() | ||
|
||
# wrap results | ||
return self.wrap_results(results, res_index) | ||
|
@@ -1021,6 +1074,10 @@ def apply_series_generator(self) -> tuple[ResType, Index]: | |
|
||
return results, res_index | ||
|
||
def apply_series_numba(self): | ||
results = self.apply_with_numba() | ||
return results, self.result_index | ||
|
||
def wrap_results(self, results: ResType, res_index: Index) -> DataFrame | Series: | ||
from pandas import Series | ||
|
||
|
@@ -1060,6 +1117,56 @@ class FrameRowApply(FrameApply): | |
def series_generator(self) -> Generator[Series, None, None]: | ||
return (self.obj._ixs(i, axis=1) for i in range(len(self.columns))) | ||
|
||
@staticmethod | ||
@functools.cache | ||
def generate_numba_apply_func( | ||
func, nogil=True, nopython=True, parallel=False | ||
) -> Callable[[npt.NDArray, Index, Index], dict[int, Any]]: | ||
from pandas import Series | ||
|
||
# Dummy import just to make the extensions loaded in | ||
# This isn't an entrypoint since we don't want users | ||
# using Series/DF in numba code outside of apply | ||
from pandas.core._numba.extensions import SeriesType # noqa: F401 | ||
|
||
numba = import_optional_dependency("numba") | ||
|
||
jitted_udf = numba.extending.register_jitable(func) | ||
|
||
@numba.jit(nogil=nogil, nopython=nopython, parallel=parallel) | ||
def numba_func(values, col_names, df_index): | ||
results = {} | ||
for j in range(values.shape[1]): | ||
# Create the series | ||
ser = Series(values[:, j], index=df_index, name=str(col_names[j])) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Why do we need to There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I restrict it to only allow string names. The cast to string is a quirk of my hack. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Nvm, I think I see what you mean. I think I need to change it so it only cassts to string when the index is str dtype. |
||
results[j] = jitted_udf(ser) | ||
return results | ||
|
||
return numba_func | ||
|
||
def apply_with_numba(self) -> dict[int, Any]: | ||
nb_func = self.generate_numba_apply_func( | ||
cast(Callable, self.func), **self.engine_kwargs | ||
) | ||
orig_values = self.columns.to_numpy() | ||
fixed_cols = False | ||
if orig_values.dtype == object: | ||
if not lib.is_string_array(orig_values): | ||
raise ValueError( | ||
"The numba engine only supports " | ||
"using string or numeric column names" | ||
) | ||
col_names_values = orig_values.astype("U") | ||
# Remember to set this back! | ||
self.columns._data = col_names_values | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Why does this need to be assigned to There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Oh is this needed due to how the numba extension is defined? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yeah, in Index, we don't allow numpy string dtypes, but my hack uses numpy string dtypes, since those already have a native representation in numba. Let me know if this is too hacky. |
||
fixed_cols = True | ||
df_index = self.obj.index | ||
|
||
res = dict(nb_func(self.values, self.columns, df_index)) | ||
if fixed_cols: | ||
self.columns._data = orig_values | ||
return res | ||
|
||
@property | ||
def result_index(self) -> Index: | ||
return self.columns | ||
|
@@ -1143,6 +1250,64 @@ def series_generator(self) -> Generator[Series, None, None]: | |
object.__setattr__(ser, "_name", name) | ||
yield ser | ||
|
||
@staticmethod | ||
@functools.cache | ||
def generate_numba_apply_func( | ||
func, nogil=True, nopython=True, parallel=False | ||
) -> Callable[[npt.NDArray, Index, Index], dict[int, Any]]: | ||
# Dummy import just to make the extensions loaded in | ||
# This isn't an entrypoint since we don't want users | ||
# using Series/DF in numba code outside of apply | ||
from pandas import Series | ||
from pandas.core._numba.extensions import SeriesType # noqa: F401 | ||
|
||
numba = import_optional_dependency("numba") | ||
|
||
jitted_udf = numba.extending.register_jitable(func) | ||
|
||
@numba.jit(nogil=nogil, nopython=nopython, parallel=parallel) | ||
def numba_func(values, col_names_index, index): | ||
results = {} | ||
for i in range(values.shape[0]): | ||
# Create the series | ||
# TODO: values corrupted without the copy | ||
ser = Series(values[i].copy(), index=col_names_index, name=index[i]) | ||
results[i] = jitted_udf(ser) | ||
|
||
return results | ||
|
||
return numba_func | ||
|
||
def apply_with_numba(self) -> dict[int, Any]: | ||
nb_func = self.generate_numba_apply_func( | ||
cast(Callable, self.func), **self.engine_kwargs | ||
) | ||
|
||
# Since numpy/numba doesn't support object array of stringswell | ||
# we'll do a sketchy thing where if index._data is object | ||
# we convert to string and directly set index._data to that, | ||
# setting it back after we call the function | ||
fixed_obj_dtype = False | ||
orig_data = self.columns.to_numpy() | ||
if self.columns._data.dtype == object: | ||
if not lib.is_string_array(orig_data): | ||
raise ValueError( | ||
"The numba engine only supports " | ||
"using string or numeric column names" | ||
) | ||
# Remember to set this back!!! | ||
self.columns._data = orig_data.astype("U") | ||
fixed_obj_dtype = True | ||
|
||
# Convert from numba dict to regular dict | ||
# Our isinstance checks in the df constructor don't pass for numbas typed dict | ||
res = dict(nb_func(self.values, self.columns, self.obj.index)) | ||
|
||
if fixed_obj_dtype: | ||
self.columns._data = orig_data | ||
|
||
return res | ||
|
||
@property | ||
def result_index(self) -> Index: | ||
return self.index | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
? (and below)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for pointing this out!
I think for now it'll probably make better sense to disable parallel mode for now, since the dict in numba isn't thread-safe.
The overhead from the boxing/unboxing is also really high (99% of the time spent is there), so I doubt parallel will give a good speedup, at least for now.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
OK makes sense. Would be good to put a TODO: comment explaining why we shouldn't use prange for now
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
added a comment.