Inplace on sit-sot / mit-sot when nsteps is symbolic #1283

Open
ricardoV94 opened this issue Mar 10, 2025 · 0 comments

ricardoV94 commented Mar 10, 2025

Description

Follow up to #1281

In the numba backend we allow inplacing of the inner sit-sot / oldest mit-sot when we know the buffer is only large enough to store the most recent taps. However, when the n_steps passed to a Scan is symbolic, PyTensor doesn't figure this out. Note how in the second graph the inner scan Composite doesn't destroy the input *0.

from pytensor import function, scan
import pytensor.tensor as pt

for constant_n_steps in (True, False):    
    print(f"{constant_n_steps=}")
    init_x = pt.vector("init_x", shape=(2,))
    n_steps = pt.iscalar("n_steps")

    def f_pow2(x_tm2, x_tm1):
        return 2 * x_tm1 + x_tm2
    
    trace, _ = scan(
        f_pow2,
        sequences=[],
        outputs_info=[{"initial": init_x, "taps": [-2, -1]}],
        non_sequences=[],
        n_steps=10 if constant_n_steps else n_steps,
    )
    fn = function([init_x, n_steps], trace[-1], on_unused_input="ignore", mode="NUMBA")
    fn.dprint(print_memory_map=True, print_shape=True)
    
# constant_n_steps=True
# Subtensor{i} [id A] shape=() v={0: [0]} 3
#  ├─ Scan{scan_fn, while_loop=False, inplace=all} [id B] shape=(?,) d={0: [1]} 2
#  │  ├─ 10 [id C] shape=()
#  │  └─ SetSubtensor{:stop} [id D] shape=(2,) d={0: [0]} 1
#  │     ├─ AllocEmpty{dtype='float64'} [id E] shape=(2,) 0
#  │     │  └─ 2 [id F] shape=()
#  │     ├─ init_x [id G] shape=(2,)
#  │     └─ 2 [id H] shape=()
#  └─ 1 [id I] shape=()

# Scan{scan_fn, while_loop=False, inplace=all} [id B] d={0: [1]}
#  ← Composite{((2.0 * i0) + i1)} [id J] shape=() d={0: [1]}
#     ├─ *1-<Scalar(float64, shape=())> [id K] shape=() -> [id D]
#     └─ *0-<Scalar(float64, shape=())> [id L] shape=() -> [id D]


# constant_n_steps=False
# Subtensor{i} [id A] shape=() v={0: [0]} 5
#  ├─ Scan{scan_fn, while_loop=False, inplace=all} [id B] shape=(?,) d={0: [1]} 4
#  │  ├─ Composite{...}.0 [id C] shape=() 0
#  │  │  └─ n_steps [id D] shape=()
#  │  └─ SetSubtensor{:stop} [id E] shape=(?,) d={0: [0]} 3
#  │     ├─ AllocEmpty{dtype='float64'} [id F] shape=(?,) 2
#  │     │  └─ Composite{...}.2 [id C] shape=() 0
#  │     │     └─ ···
#  │     ├─ init_x [id G] shape=(2,)
#  │     └─ 2 [id H] shape=()
#  └─ ScalarFromTensor [id I] shape=() 1
#     └─ Composite{...}.1 [id C] shape=() 0
#        └─ ···

# Inner graphs:
# Scan{scan_fn, while_loop=False, inplace=all} [id B] d={0: [1]}
#  ← Composite{((2.0 * i0) + i1)} [id J] shape=()
#     ├─ *1-<Scalar(float64, shape=())> [id K] shape=() -> [id E]
#     └─ *0-<Scalar(float64, shape=())> [id L] shape=() -> [id E]
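
One way to check this programmatically rather than reading the dprint output is to look at the destroy maps of the nodes in the Scan's inner graph. This is a minimal sketch, assuming the reproduction script above (so `fn` is the compiled function with a single Scan node) and that the Scan op exposes its inner graph as `op.fgraph`:

from pytensor.scan.op import Scan

# Sketch: find the Scan node in the compiled graph and print the destroy map
# of each node in its inner graph. With constant n_steps the inner Composite
# reports d={0: [1]} (the oldest tap is destroyed); with symbolic n_steps it
# reports no destroy map.
scan_node = next(
    node for node in fn.maker.fgraph.apply_nodes if isinstance(node.op, Scan)
)
for inner_node in scan_node.op.fgraph.apply_nodes:
    print(inner_node.op, getattr(inner_node.op, "destroy_map", {}))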