Implement several subtensor lift rewrites #1158
Conversation
Codecov Report
❌ Attention: Your patch status has failed because the patch coverage (90.97%) is below the target coverage (100.00%). You can increase the patch coverage or adjust the target coverage.

@@            Coverage Diff             @@
##             main    #1158      +/-   ##
==========================================
+ Coverage   82.02%   82.07%   +0.05%
==========================================
  Files         207      208       +1
  Lines       49301    49517     +216
  Branches     8747     8785      +38
==========================================
+ Hits        40440    40642     +202
- Misses       6695     6702       +7
- Partials     2166     2173       +7
I finally got around to reviewing this. Sorry @ricardoV94 for the huge delay! I left a few questions but it looks good.
@register_canonicalize("shape_unsafe")
@register_specialize("shape_unsafe")
I don't know the difference between canonicalize and specialize.
canonicalize happens before stabilize and before specialize. Canonicalize should simplify graphs and convert different forms into a canonical one. Specialize is allowed to go bananas and produce very specialized code. We apply many rewrites in multiple phases so they can interact with others that are specific to that phase.
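For reference, a minimal sketch (using PyTensor's real registration decorators; the rewrite name and its no-op body are hypothetical) of how one node rewriter is registered in both phases, as in the decorators quoted above:

```python
from pytensor.graph.rewriting.basic import node_rewriter
from pytensor.tensor.rewriting.basic import register_canonicalize, register_specialize
from pytensor.tensor.subtensor import Subtensor


@register_canonicalize("shape_unsafe")  # runs during canonicalization
@register_specialize("shape_unsafe")    # and again during specialization
@node_rewriter([Subtensor])
def local_example_rewrite(fgraph, node):
    # Return a list of replacement outputs, or None to leave the node alone
    return None
```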
copy_stack_trace(node.outputs[0], subt_x)
elem_inputs = elem.owner.inputs
elem_bcast = elem.type.broadcastable
if all(inp.type.broadcastable == elem_bcast for inp in elem_inputs):
Will the Elemwise inputs already have the same number of dimensions as the output at this point?
Yes, Elemwise.make_node adds any expand_dims needed so that all inputs and the output have the same ndim.
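A quick way to see this (a hedged illustration of standard PyTensor behavior, with made-up variable names):

```python
import pytensor.tensor as pt

x = pt.matrix("x")  # ndim=2
y = pt.vector("y")  # ndim=1
z = x + y           # Elemwise add

# The vector input was wrapped in a DimShuffle (expand_dims) so that both
# Elemwise inputs have ndim=2, like the output:
print(z.owner.inputs[1].type.broadcastable)  # (True, False)
```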
[old_out] = node.outputs

# Copy stack trace to new inputs
[copy_stack_trace(old_out, new_inp) for new_inp in indexed_inputs]
What does this function do?
When there's a runtime error evaluating a node, it shows the stack trace of where the original variable was first defined by the user, even if that variable was later replaced by rewrites and no longer exists in the computational graph.
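In code terms, a small hedged sketch (copy_stack_trace is PyTensor's real helper; the variables here are illustrative):

```python
import pytensor.tensor as pt
from pytensor.graph.rewriting.basic import copy_stack_trace

x = pt.vector("x")
old = pt.exp(x)        # the user's definition site is recorded on old.tag.trace
new = 2.0 * pt.exp(x)  # stand-in for a replacement a rewrite would build
copy_stack_trace(old, new)  # new now reports the user's original location
```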
For now rewrite is restricted to single axis of reduction, for simplicity.

sum(x, axis=1)[0] -> sum(x[0], axis=0)
Does this rewrite handle the case where keepdims=True? Or does that kwarg produce a different Op than CAReduce?
keepdims is not part of CAReduce; it's added as a separate expand_dims after the reduction.
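This is easy to check (a hedged sketch; the printed op names reflect my reading of the graph structure):

```python
import pytensor.tensor as pt

x = pt.matrix("x")
s = pt.sum(x, axis=1, keepdims=True)

# keepdims is implemented as an expand_dims (DimShuffle) wrapped around
# the CAReduce node, not as a different reduction Op:
print(type(s.owner.op).__name__)                  # DimShuffle
print(type(s.owner.inputs[0].owner.op).__name__)  # Sum (a CAReduce)
```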
if len(fgraph.clients[adv_subtensor]) > 1:
    # AdvancedSubtensor involves a full_copy, so we don't want to do it twice
    return None
Add pragma: no cover to make codecov shut up.
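That is, something like the following (pragma: no cover is the standard coverage.py exclusion comment; applying it here is the reviewer's suggestion, not code from the PR):

```python
if len(fgraph.clients[adv_subtensor]) > 1:
    # AdvancedSubtensor involves a full_copy, so we don't want to do it twice
    return None  # pragma: no cover
```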
@node_rewriter([Subtensor])
def local_subtensor_of_adv_subtensor(fgraph, node):
    """Lift a simple Subtensor through an AdvancedSubtensor, when basic index dimensions are to the left of any advanced ones.
This rewrite made me question something about the other commits. If this rewrite is lifting the subtensor out of the advanced subtensor operation, aren't some of your other rewrites actually lowering the subtensor so that it gets applied before the other ops?
When we say lifting we mean applying closer to the inputs of the function (before other operations). What you're describing would be sinking/lowering. In general we want to apply the subtensor before other elementwise operations, so as to compute fewer things. Advanced indexing could be a case to lower, if we end up with more rows than before, but we don't have any rewrites to reason about that.
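As a concrete hedged illustration of lifting (real PyTensor API; the rewritten graph described in the comment is indicative, not an exact dprint transcript):

```python
import pytensor
import pytensor.tensor as pt

x = pt.matrix("x")
out = pt.exp(x)[0]  # the user exponentiates the whole matrix, then indexes

# After rewriting, the Subtensor is lifted closer to the input, so exp is
# only computed on the selected row: roughly exp(x[0]) instead of exp(x)[0].
f = pytensor.function([x], out)
pytensor.dprint(f)
```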
Split off Subtensor of Unbroadcast into its own rewrite
This allows reducing computations on batch dimensions by lifting simple indexing operations closer to the inputs.
An obvious example is:
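One such case, a hedged sketch based on the sum(x, axis=1)[0] -> sum(x[0], axis=0) rewrite discussed in the review above:

```python
import pytensor.tensor as pt

x = pt.matrix("x")
out = pt.sum(x, axis=1)[0]  # row sums of the whole matrix, then take one

# The rewrite turns this into sum(x[0], axis=0): only the first row is
# ever summed, so the computation on all other rows is skipped entirely.
```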
📚 Documentation preview 📚: https://pytensor--1158.org.readthedocs.build/en/1158/