Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Propagate static output shapes in Split and avoid copy in C-impl #1343

Merged
merged 4 commits into from
Apr 8, 2025

Conversation

ricardoV94
Copy link
Member

@ricardoV94 ricardoV94 commented Apr 3, 2025

This makes Split (which shows up in the gradient of Join) much faster as it doesn't do useless copies.

I see a speedup of ~10x, obviously the comparison would scale with the ammount of copying that is now avoided

import pytensor
import pytensor.tensor as pt
import numpy as np

x = pt.matrix("x", shape=(100, 200))
ys = pt.split(x, [10]*10, 10)
profile = None
fn = pytensor.function(
    [pytensor.In(x, borrow=True)], 
    [pytensor.Out(y, borrow=True) for y in ys],
    trust_input=True,
    profile=profile,
)
fn.dprint()
x_test = np.zeros((100, 200))
%timeit fn(x_test)
%timeit fn(x_test)
%timeit fn(x_test)
if profile:
    fn.profile.summary()

Also added static output shape and cleanup other methods of Split


📚 Documentation preview 📚: https://pytensor--1343.org.readthedocs.build/en/1343/

@ricardoV94 ricardoV94 force-pushed the faster_split branch 2 times, most recently from aa9b281 to 4331934 Compare April 3, 2025 11:59
@ricardoV94 ricardoV94 mentioned this pull request Apr 3, 2025
@ricardoV94 ricardoV94 changed the title Make Split C-impl return a view Propagate static output shapes in Split and avoid copy in C-impl Apr 3, 2025
Copy link

codecov bot commented Apr 3, 2025

Codecov Report

Attention: Patch coverage is 88.57143% with 4 lines in your changes missing coverage. Please review.

Project coverage is 82.01%. Comparing base (0f5da80) to head (f1d6ba0).
Report is 5 commits behind head on main.

Files with missing lines Patch % Lines
pytensor/tensor/basic.py 88.23% 3 Missing and 1 partial ⚠️

❌ Your patch status has failed because the patch coverage (88.57%) is below the target coverage (100.00%). You can increase the patch coverage or adjust the target coverage.

Additional details and impacted files

Impacted file tree graph

@@           Coverage Diff           @@
##             main    #1343   +/-   ##
=======================================
  Coverage   82.01%   82.01%           
=======================================
  Files         203      203           
  Lines       48798    48813   +15     
  Branches     8685     8688    +3     
=======================================
+ Hits        40022    40035   +13     
- Misses       6625     6627    +2     
  Partials     2151     2151           
Files with missing lines Coverage Δ
pytensor/link/numba/dispatch/tensor_basic.py 88.59% <100.00%> (-0.20%) ⬇️
pytensor/tensor/basic.py 91.13% <88.23%> (-0.04%) ⬇️
🚀 New features to boost your workflow:
  • ❄️ Test Analytics: Detect flaky tests, report on failures, and find test suite problems.

Copy link
Member

@jessegrabowski jessegrabowski left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Some rinky-dink feedback. I'm not qualified to comment on the C code, but I tried my best.

@ricardoV94 ricardoV94 merged commit 4e59f21 into pymc-devs:main Apr 8, 2025
72 of 73 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants