-
Notifications
You must be signed in to change notification settings - Fork 135
Fix slow dot in numba #1426
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Fix slow dot in numba #1426
Conversation
if all( | ||
input.type.numpy_dtype == np.dtype(out_dtype) for input in inputs | ||
) and isinstance(np.dtype(out_dtype), np.floating): | ||
if ( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The old check doesn't work or stopped working at some point:
assert (
isinstance(np.dtype("float64"), np.floating)
or isinstance(np.float64, np.floating)
) # fails
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We don't use this function anymore in dot, but it's used elsewhere
@numba_njit | ||
def dot(x, y): | ||
return np.asarray(np.dot(inputs_cast(x), inputs_cast(y))).astype(out_dtype) | ||
if x_dtype == dot_dtype and y_dtype == dot_dtype: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I need these branches otherwise, I get some failed to unify
during numba compilation.
It doesn't like the pattern:
if x.dtype != dot_dtype:
x = x.astype(dot_dtype)
When it's actually needed. Can be simplified once astype(copy=False)
is implemented in numba
Codecov ReportAttention: Patch coverage is
Additional details and impacted files@@ Coverage Diff @@
## main #1426 +/- ##
=======================================
Coverage 82.11% 82.11%
=======================================
Files 211 211
Lines 49686 49743 +57
Branches 8813 8824 +11
=======================================
+ Hits 40798 40847 +49
- Misses 6710 6715 +5
- Partials 2178 2181 +3
🚀 New features to boost your workflow:
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looks good :-)
Closes #1418
I confirmed with
perf
that numba is calling gemv, the problem was the castingsastype
to handle mixed or discrete dtypes always force a copy even when the cast dtype is the same as the input (discussed in numba/numba#10085)We should check the other functions that make use of
int_to_float_fn
(mostly linalg stuff), and adapt accordingly. This PR is just to fix the dot case.Benchmark results
It now matches the C-backend performance in my machine
📚 Documentation preview 📚: https://pytensor--1426.org.readthedocs.build/en/1426/