Remove Unbroadcast Op #1286
Merged
Changes from all commits (2 commits)
@@ -58,7 +58,11 @@
 from pytensor.tensor.elemwise import DimShuffle, Elemwise
 from pytensor.tensor.exceptions import NotScalarConstantError
 from pytensor.tensor.math import Dot, dot, maximum, minimum
-from pytensor.tensor.rewriting.basic import constant_folding, local_useless_switch
+from pytensor.tensor.rewriting.basic import (
+    broadcasted_by,
+    constant_folding,
+    local_useless_switch,
+)
 from pytensor.tensor.rewriting.elemwise import local_upcast_elemwise_constant_inputs
 from pytensor.tensor.rewriting.math import local_abs_merge, local_mul_switch_sink
 from pytensor.tensor.shape import shape
@@ -1182,6 +1186,44 @@ def while_scan_merge_subtensor_last_element(fgraph, scan_node):
     return subtensor_merge_replacements
 
 
+def _is_default_scan_buffer(x: TensorVariable) -> bool:
+    node = x.owner
+
+    if node is None:
+        return False
+
+    op = node.op
+    if not (
+        isinstance(op, IncSubtensor)
+        and op.set_instead_of_inc
+        and op.idx_list == [slice(None, ps.int64)]
+    ):
+        return False
+
+    x, y, *_ = node.inputs
+    if not (x.owner is not None and isinstance(x.owner.op, AllocEmpty)):
+        return False
+
+    # The value may have been broadcast to fill in the initial taps.
+    # If the user specified outputs as:
+    #   x = scalar(); init = alloc(x, 2);
+    #   outputs_info=[init, taps=(-2, -1)]
+    # Scan will generate an initial buffer that looks like
+    #   alloc_empty(2 + nsteps)[:2].set(alloc(x, 2))
+    # PyTensor will then rewrite it as:
+    #   alloc_empty(2 + nsteps)[:2].set(x)
+    # When the initial value (x) is being broadcast by the set_subtensor
+    # we can't recreate a newly sized buffer working with x alone.
+    # We want to check that:
+    #   1. alloc_empty(2 + nsteps)[:2].broadcastable == x.broadcastable
+    # But due to laziness we use the slightly more conservative check:
+    #   2. alloc_empty(2 + nsteps).broadcastable == x.broadcastable
+    if broadcasted_by(y, x):
+        return False
+
+    return True
+
+
 def scan_save_mem_rewrite(fgraph, node, backend_supports_output_pre_allocation: bool):
     r"""Graph optimizer that reduces scan memory consumption.
 
@@ -1520,51 +1562,28 @@ def scan_save_mem_rewrite(fgraph, node, backend_supports_output_pre_allocation:
 
         # 3.2 check orphane outputs to see if we can eliminate any
         required, not_required = scan_can_remove_outs(node.op, orphane_outs)
-        # 3.3. compose replace pairs for those nodes that need not
-        # to store everything in memory ( or ar orphane and required
-        # by the inner function .. )
+
+        # 3.3. compose replace pairs for those nodes that need not store everything in memory
+        # (or ar orphan but required by the inner function)
         replaced_outs = []
         offset = 1 + op_info.n_seqs + op_info.n_mit_mot
-        for idx, _val in enumerate(store_steps[op_info.n_mit_mot :]):
+        for idx, val in enumerate(store_steps[op_info.n_mit_mot :]):
             i = idx + op_info.n_mit_mot
-            if not (isinstance(_val, int) and _val <= 0 and i not in required):
-                if idx + op_info.n_mit_mot in required:
-                    val = 1
-                else:
-                    val = _val
+            if not (isinstance(val, int) and val <= 0 and i not in required):
+                required_orphan = idx + op_info.n_mit_mot in required
                 # If the memory for this output has been pre-allocated
                 # before going into the scan op (by an alloc node)
                 if idx < op_info.n_mit_sot + op_info.n_sit_sot:
-                    # In case the input is still an alloc node, we
-                    # actually have two options:
-                    # a) the input is a set_subtensor, in that case we
-                    # can replace the initial tensor by a slice,
-                    # b) it is not, and we simply take a slice of it.
-                    # TODO: commit change below with Razvan
-                    if (
-                        nw_inputs[offset + idx].owner
-                        and isinstance(nw_inputs[offset + idx].owner.op, IncSubtensor)
-                        and nw_inputs[offset + idx].owner.op.set_instead_of_inc
-                        and isinstance(
-                            nw_inputs[offset + idx].owner.op.idx_list[0], slice
-                        )
-                        # Don't try to create a smart Alloc, if set_subtensor is broadcasting the fill value
-                        # As it happens in set_subtensor(empty(2)[:], 0)
-                        and not (
-                            nw_inputs[offset + idx].ndim
-                            > nw_inputs[offset + idx].owner.inputs[1].ndim
-                        )
-                    ):
-                        _nw_input = nw_inputs[offset + idx].owner.inputs[1]
-                        cval = pt.as_tensor_variable(val)
-                        initl = pt.as_tensor_variable(init_l[i])
-                        tmp_idx = pt.switch(cval < initl, cval + initl, cval - initl)
-                        nw_input = expand_empty(_nw_input, tmp_idx)
+                    nw_input = nw_inputs[offset + idx]
+
+                    # Recreate default buffers with new size
+                    if _is_default_scan_buffer(nw_input):
+                        extra_size = 1 if required_orphan else val - init_l[i]
+                        nw_input = expand_empty(nw_input.owner.inputs[1], extra_size)
+                    # Otherwise, just trim with a slice
                     else:
-                        tmp = pt.as_tensor_variable(val)
-                        initl = pt.as_tensor_variable(init_l[i])
-                        tmp = maximum(tmp, initl)
-                        nw_input = nw_inputs[offset + idx][:tmp]
+                        stop = init_l[i] if required_orphan else val
+                        nw_input = nw_input[:stop]
 
                     nw_inputs[offset + idx] = nw_input
                     replaced_outs.append(op_info.n_mit_mot + idx)
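For reference, a minimal sketch of what the two branches above produce, using `expand_empty` from `pytensor.scan.utils` (already used by this file); the concrete sizes are made up for illustration:

```python
import pytensor.tensor as pt
from pytensor.scan.utils import expand_empty

init = pt.vector("init", shape=(2,))  # initial taps kept by the output

# Default-buffer branch: rebuild the buffer from the initial value with
# exactly the extra storage the rewrite decided on (here 5 steps), i.e.
# empty(2 + 5) with the first 2 entries set to `init`.
resized = expand_empty(init, 5)

# Fallback branch: any other buffer is simply trimmed with a slice.
buf = pt.vector("buf")
stop = pt.iscalar("stop")
trimmed = buf[:stop]
```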
@@ -1588,7 +1607,7 @@ def scan_save_mem_rewrite(fgraph, node, backend_supports_output_pre_allocation:
                         + op_info.n_shared_outs
                     )
                     if nw_inputs[pos] == node.inputs[0]:
-                        nw_inputs[pos] = val
+                        nw_inputs[pos] = 1 if required_orphan else val
                     odx = op_info.n_mit_mot + idx
                     replaced_outs.append(odx)
                     old_outputs += [
@@ -1600,37 +1619,22 @@ def scan_save_mem_rewrite(fgraph, node, backend_supports_output_pre_allocation:
                         ],
                     )
                 ]
-        # 3.4. Recompute inputs for everything else based on the new
-        # number of steps
+        # 3.4. Recompute inputs for everything else based on the new number of steps
         if global_nsteps is not None:
             for idx, val in enumerate(store_steps[op_info.n_mit_mot :]):
                 if val == 0:
                     # val == 0 means that we want to keep all intermediate
                     # results for that state, including the initial values.
                     if idx < op_info.n_mit_sot + op_info.n_sit_sot:
                         in_idx = offset + idx
-                        # Number of steps in the initial state
-                        initl = init_l[op_info.n_mit_mot + idx]
-
-                        # If the initial buffer has the form
-                        # inc_subtensor(zeros(...)[...], _nw_input)
-                        # we want to make the zeros tensor as small as
-                        # possible (nw_steps + initl), and call
-                        # inc_subtensor on that instead.
-                        # Otherwise, simply take 0:(nw_steps+initl).
-                        if (
-                            nw_inputs[in_idx].owner
-                            and isinstance(nw_inputs[in_idx].owner.op, IncSubtensor)
-                            and isinstance(
-                                nw_inputs[in_idx].owner.op.idx_list[0], slice
-                            )
-                        ):
-                            _nw_input = nw_inputs[in_idx].owner.inputs[1]
-                            nw_input = expand_empty(_nw_input, nw_steps)
-                            nw_inputs[in_idx] = nw_input
-                        else:
-                            # FIXME: This is never used
-                            nw_input = nw_inputs[in_idx][: (initl + nw_steps)]
+                        nw_input = nw_inputs[in_idx]
+                        if _is_default_scan_buffer(nw_input):
+                            nw_input = expand_empty(nw_input.owner.inputs[1], nw_steps)
+                        else:
+                            # Number of steps in the initial state
+                            init_l_pt = pt.as_tensor(init_l[op_info.n_mit_mot + idx])
+                            nw_input = nw_input[: (init_l_pt + nw_steps)]
+                        nw_inputs[in_idx] = nw_input
 
                 elif (
                     idx < op_info.n_mit_sot + op_info.n_sit_sot + op_info.n_nit_sot

Review comment on the removed `# FIXME: This is never used` branch: "is that FIXME in the old code still relevant?"
Author reply: "Fixed with these changes."
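As a usage illustration (not from the PR), here is the kind of graph these rewrites target: only the last element of a two-tap recurrence is requested, so the save-memory rewrite can shrink the default `(2 + n)`-long buffer down to the initial taps plus one step. The recurrence itself is a made-up example:

```python
import pytensor
import pytensor.tensor as pt

x0 = pt.vector("x0", shape=(2,))
n = pt.iscalar("n")

# Fibonacci-style recurrence with taps (-2, -1).
ys, _ = pytensor.scan(
    fn=lambda xtm2, xtm1: xtm2 + xtm1,
    outputs_info=[{"initial": x0, "taps": [-2, -1]}],
    n_steps=n,
)

# Only ys[-1] is used, so scan_save_mem replaces the (2 + n)-long default
# buffer with a much smaller one when the function is compiled.
f = pytensor.function([x0, n], ys[-1])
```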