Fix slow dot in numba #1426


Merged: 1 commit merged into pymc-devs:main on May 27, 2025

Conversation

ricardoV94 (Member) commented on May 27, 2025

Closes #1418

I confirmed with perf that numba is calling gemv; the problem was that the astype casts used to handle mixed or discrete dtypes always force a copy, even when the target dtype is the same as the input's (discussed in numba/numba#10085).

We should check the other functions that make use of int_to_float_fn (mostly linalg stuff) and adapt them accordingly. This PR only fixes the dot case.
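
To illustrate the copy issue, here is a minimal sketch of my own (the function names are illustrative, not the PR's code), assuming float64 inputs: in a numba-jitted function, astype currently allocates a fresh array even when the array already has the target dtype, so skipping the cast when the dtypes already match lets np.dot hand the original buffers straight to gemv.

import numpy as np
from numba import njit

@njit
def dot_with_cast(A, x):
    # astype always copies in numba (numba/numba#10085), even when
    # A and x are already float64, so gemv operates on fresh buffers
    return np.dot(A.astype(np.float64), x.astype(np.float64))

@njit
def dot_without_cast(A, x):
    # no cast, no copy: gemv sees the original arrays
    return np.dot(A, x)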

Benchmark results

Before
------------------------------------------------------------------------------------------------- benchmark: 3 tests -------------------------------------------------------------------------------------------------
Name (time in us)                              Min                   Max                Mean             StdDev              Median                IQR            Outliers  OPS (Kops/s)            Rounds  Iterations
----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
test_mat_vec_dot_performance[float32]      48.5510 (1.0)      1,294.7070 (3.11)      61.3219 (1.0)      14.6329 (1.0)       60.6585 (1.0)       2.5140 (1.0)      231;2385       16.3074 (1.0)        9052           1
test_mat_vec_dot_performance[mixed]       150.2420 (3.09)       514.2840 (1.23)     172.9162 (2.82)     34.0547 (2.33)     159.1785 (2.62)     19.9170 (7.92)      298;300        5.7831 (0.35)       3162           1
test_mat_vec_dot_performance[float64]     158.5070 (3.26)       416.6410 (1.0)      181.0844 (2.95)     33.2207 (2.27)     165.9410 (2.74)     23.4185 (9.32)        73;59        5.5223 (0.34)        725           1
----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------

After
----------------------------------------------------------------------------------------------- benchmark: 3 tests ----------------------------------------------------------------------------------------------
Name (time in us)                             Min                 Max               Mean             StdDev             Median                IQR            Outliers  OPS (Kops/s)            Rounds  Iterations
-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
test_mat_vec_dot_performance[float32]     23.5340 (1.0)       69.0690 (1.0)      30.1434 (1.0)       3.2675 (1.0)      29.1650 (1.0)       0.5820 (1.0)     1465;3237       33.1748 (1.0)       16409           1
test_mat_vec_dot_performance[float64]     40.3260 (1.71)     159.3380 (2.31)     51.0019 (1.69)      8.8649 (2.71)     51.2260 (1.76)      9.6455 (16.57)    1587;272       19.6071 (0.59)       6019           1
test_mat_vec_dot_performance[mixed]       40.5660 (1.72)     180.4480 (2.61)     53.0003 (1.76)     16.4452 (5.03)     50.8750 (1.74)     11.4110 (19.61)     923;845       18.8678 (0.57)       9525           1
-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------

It now matches the C-backend performance on my machine.
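
For reference, a standalone comparison along these lines (my own sketch, not the test suite's test_mat_vec_dot_performance benchmark; the array sizes and iteration count are arbitrary) can be used to time the numba backend against the default backend:

import timeit
import numpy as np
import pytensor
import pytensor.tensor as pt

A = pt.matrix("A", dtype="float64")
x = pt.vector("x", dtype="float64")
out = pt.dot(A, x)

f_numba = pytensor.function([A, x], out, mode="NUMBA")
f_default = pytensor.function([A, x], out)  # uses the C backend when available

rng = np.random.default_rng(0)
A_val = rng.standard_normal((1024, 1024))
x_val = rng.standard_normal(1024)

f_numba(A_val, x_val)  # warm-up call triggers numba compilation
print("numba:  ", timeit.timeit(lambda: f_numba(A_val, x_val), number=1000))
print("default:", timeit.timeit(lambda: f_default(A_val, x_val), number=1000))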


📚 Documentation preview 📚: https://pytensor--1426.org.readthedocs.build/en/1426/

Review context (the old check replaced in this PR):

if all(
    input.type.numpy_dtype == np.dtype(out_dtype) for input in inputs
) and isinstance(np.dtype(out_dtype), np.floating):
ricardoV94 (Member, Author) commented:

The old check doesn't work or stopped working at some point:

assert (
    isinstance(np.dtype("float64"), np.floating)  # False: np.dtype objects are not scalar instances
    or isinstance(np.float64, np.floating)        # False: np.float64 is a type, not an instance of it
)  # fails
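
For comparison, a subtype-based check does behave as intended (shown only as a reference point, not necessarily the exact check this PR switched to):

import numpy as np

assert np.issubdtype(np.dtype("float64"), np.floating)   # passes
assert not isinstance(np.dtype("float64"), np.floating)  # dtype objects are not np.floating instances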

ricardoV94 (Member, Author) commented:

We don't use this function anymore in dot, but it's used elsewhere

@numba_njit
def dot(x, y):
    return np.asarray(np.dot(inputs_cast(x), inputs_cast(y))).astype(out_dtype)

Review context (the new branching code):

if x_dtype == dot_dtype and y_dtype == dot_dtype:
ricardoV94 (Member, Author) commented:

I need these branches; otherwise I get "failed to unify" errors during numba compilation.
It doesn't like the pattern:

if x.dtype != dot_dtype:
  x = x.astype(dot_dtype)

when the cast is actually needed. This can be simplified once astype(copy=False) is implemented in numba.
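
A minimal sketch of that kind of dispatch-time branching (the names make_dot, x_dtype, y_dtype, and dot_dtype mirror the context line above, but the body is illustrative rather than the exact PyTensor implementation): the dtype comparison happens in plain Python while building the jitted function, so inside each compiled body every variable keeps a single numba type.

import numpy as np
from numba import njit

def make_dot(x_dtype, y_dtype, dot_dtype=np.float64):
    # Resolve the casts while generating the function, not inside it,
    # so numba never has to unify `x` with `x.astype(dot_dtype)`.
    if x_dtype == dot_dtype and y_dtype == dot_dtype:
        @njit
        def dot(x, y):
            return np.dot(x, y)
    elif x_dtype == dot_dtype:
        @njit
        def dot(x, y):
            return np.dot(x, y.astype(dot_dtype))
    elif y_dtype == dot_dtype:
        @njit
        def dot(x, y):
            return np.dot(x.astype(dot_dtype), y)
    else:
        @njit
        def dot(x, y):
            return np.dot(x.astype(dot_dtype), y.astype(dot_dtype))
    return dot

For instance, make_dot(np.dtype("int32"), np.dtype("float64")) returns a jitted dot that upcasts only its first argument.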


codecov bot commented May 27, 2025

Codecov Report

Attention: Patch coverage is 81.25000% with 6 lines in your changes missing coverage. Please review.

Project coverage is 82.11%. Comparing base (261aaf3) to head (5f7a740).
Report is 10 commits behind head on main.

Files with missing lines                   Patch %   Lines
pytensor/link/numba/dispatch/basic.py      81.25%    4 Missing and 2 partials ⚠️
Additional details and impacted files


@@           Coverage Diff           @@
##             main    #1426   +/-   ##
=======================================
  Coverage   82.11%   82.11%           
=======================================
  Files         211      211           
  Lines       49686    49743   +57     
  Branches     8813     8824   +11     
=======================================
+ Hits        40798    40847   +49     
- Misses       6710     6715    +5     
- Partials     2178     2181    +3     
Files with missing lines                   Coverage Δ
pytensor/link/numba/dispatch/basic.py      79.08% <81.25%> (-0.46%) ⬇️

... and 6 files with indirect coverage changes


aseyboldt (Member) left a comment:

Looks good :-)

ricardoV94 merged commit 5a462e9 into pymc-devs:main on May 27, 2025
73 of 74 checks passed
ricardoV94 deleted the fix_slow_dot_numba branch on May 28, 2025

Successfully merging this pull request may close these issues.

Implement gemv numba dispatch