Fix bug in infer_shape of Blockwise(Subtensor) #1353


Merged
ricardoV94 merged 3 commits into pymc-devs:main on Apr 9, 2025

Conversation

ricardoV94
Member

@ricardoV94 ricardoV94 commented Apr 9, 2025

Also adds some rewrites to make sure we end up with a clean gradient graph for a batched Conv1D.

The bug itself showed up in pymc-labs/pymc-marketing#1583


📚 Documentation preview 📚: https://pytensor--1353.org.readthedocs.build/en/1353/

Sometimes `_create_dummy_core_node` can create a multi-node graph whose root inputs are not `node.inputs`. `infer_shape` may then bypass the intermediate nodes. This was the case with `Subtensor`, which introduces `ScalarFromTensor` nodes but ignores them in the shape graph (for a cleaner graph).
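A minimal toy model of the failure mode described in the commit message (plain Python, no PyTensor; all class and variable names here are hypothetical, for illustration only): the dummy node's direct input is a `ScalarFromTensor`-style wrapper, while the shape expression references the root variable directly, so a membership check against the node's direct inputs misses the dependency.

```python
class Var:
    def __init__(self, name, owner=None):
        self.name = name
        self.owner = owner  # Node that produced this Var, or None for roots


class Node:
    def __init__(self, op, inputs):
        self.op, self.inputs = op, inputs
        self.out = Var(op + "_out", owner=self)


def all_vars(var):
    """Every variable reachable from `var` by following owner links."""
    seen = {var}
    if var.owner is not None:
        for inp in var.owner.inputs:
            seen |= all_vars(inp)
    return seen


x, start = Var("x"), Var("start")

# make_node wraps the index in an intermediate op...
start_scalar = Node("ScalarFromTensor", [start]).out
subtensor = Node("Subtensor", [x, start_scalar])

# ...but the shape expression depends on the root `start` value directly,
# bypassing the ScalarFromTensor wrapper (for a cleaner shape graph).
shape_graph = Node("shape_expr", [start]).out

# Buggy check: compares the node's *direct* inputs against the shape graph,
# so the value dependency on `start` goes undetected.
buggy = any(inp in all_vars(shape_graph) for inp in subtensor.inputs)

# Fixed idea: check the *root* dummy variables instead.
fixed = start in all_vars(shape_graph)
```

With this toy graph, `buggy` is `False` (the dependency is missed) while `fixed` is `True`, which is the essence of the fix: track the original dummy inputs rather than the created node's immediate inputs.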

codecov bot commented Apr 9, 2025

Codecov Report

All modified and coverable lines are covered by tests ✅

Project coverage is 82.04%. Comparing base (4e59f21) to head (0aeda21).
Report is 10 commits behind head on main.

Additional details and impacted files


@@            Coverage Diff             @@
##             main    #1353      +/-   ##
==========================================
+ Coverage   82.02%   82.04%   +0.02%     
==========================================
  Files         203      203              
  Lines       48845    48856      +11     
  Branches     8691     8693       +2     
==========================================
+ Hits        40067    40086      +19     
+ Misses       6627     6619       -8     
  Partials     2151     2151              
Files with missing lines                 Coverage Δ
pytensor/tensor/basic.py                 91.17% <100.00%> (+0.03%) ⬆️
pytensor/tensor/blockwise.py             85.64% <100.00%> (ø)
pytensor/tensor/rewriting/blockwise.py   96.68% <100.00%> (+0.28%) ⬆️

... and 4 files with indirect coverage changes


Member

@jessegrabowski jessegrabowski left a comment


Not sure I get the objective, but the code looks good

@@ -264,9 +264,13 @@ class TestOpWithInferShape(Op):
def make_node(self, a, b):
assert a.type.ndim == 1
assert b.type.ndim == 1
# Simulate make_node that introduces operations on inputs
a_identity = a.copy()
Member


Why does copying the input simulate an operation on the inputs? The apply takes the copies as inputs, so any intermediate operations (the `f` in `a_identity = f(a)`) would be lost, right?

Member Author


They would be lost after rewrites, yes, but `Blockwise.infer_shape` acts on the graph returned by `make_node` immediately, before any rewrites run.

What `Blockwise` is trying to do is figure out whether the core shape of the Op depends on the values of the inputs, or can be computed from their shapes alone. For that it calls `infer_shape` on the core op with dummy core variables and then checks whether those variables appear in the returned shape graph. If they do, it can't really use that graph, because the shape may vary over iterations (say, a Blockwise slice `Subtensor` with batched start points).

If they are not used in the shape graph, then only the core shape is needed, which is fine to use. This would be the case for a `Blockwise(Dirichlet)`, where only the core shape (the length of `alpha`), but not its value, is needed. (We don't blockwise RVs, but you get the idea.)

Anyway, the logic to figure out whether the core values are needed was to create a dummy node and then check whether the node's inputs appeared in the shape graph. But this failed when the dummy node didn't directly use the dummy variables (because `make_node` added extra nodes, like `DimShuffle` or `ScalarFromTensor` in the case that actually failed). The identity/copy here is an easy way to test this without having to change anything else in the test.
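The distinction described above can be sketched in a few lines of plain Python (hypothetical names, not PyTensor's actual API): model the shape graph as an output-to-inputs edge dict and ask whether any core input value is reachable from the shape output. Input *shapes* are modeled as separate given roots, mirroring how `infer_shape` receives them independently of the input values.

```python
def reachable_vars(out, graph):
    """All variables reachable from `out` in a dict mapping
    each output to the inputs it was computed from."""
    seen, stack = set(), [out]
    while stack:
        v = stack.pop()
        if v in seen:
            continue
        seen.add(v)
        stack.extend(graph.get(v, []))
    return seen


def depends_on_values(core_inputs, shape_graph):
    """True if the core output shape needs input *values*, not just shapes,
    in which case the shape may vary across batch iterations."""
    return any(inp in reachable_vars("out_shape", shape_graph) for inp in core_inputs)


# Dirichlet-like case: the output shape needs only the shape of alpha,
# which is a separately given root, so the core shape is reusable.
dirichlet_graph = {"out_shape": ["alpha_shape"]}

# Subtensor-like slice case: the output length needs the value of the
# start index, so the shape can differ per batch iteration.
subtensor_graph = {"out_shape": ["x_shape", "start"]}
```

Here `depends_on_values(["alpha"], dirichlet_graph)` is `False` while `depends_on_values(["x", "start"], subtensor_graph)` is `True`, matching the two cases contrasted in the comment.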

@ricardoV94 force-pushed the fix_blockwise_infer_shape branch from 1926761 to 0aeda21 on April 9, 2025 12:09
@ricardoV94 merged commit 0c398e3 into pymc-devs:main on Apr 9, 2025
73 checks passed
@ricardoV94 deleted the fix_blockwise_infer_shape branch on April 9, 2025 13:18