Skip to content

Commit 9e603cf

Browse files
authored
Provide static output shape for constant arange
1 parent cb0758c commit 9e603cf

File tree

2 files changed

+24
-1
lines changed

2 files changed

+24
-1
lines changed

Diff for: pytensor/tensor/basic.py

+17-1
Original file line numberDiff line numberDiff line change
@@ -3215,13 +3215,29 @@ def __init__(self, dtype):
32153215
self.dtype = dtype
32163216

32173217
def make_node(self, start, stop, step):
3218+
from math import ceil
3219+
32183220
start, stop, step = map(as_tensor_variable, (start, stop, step))
3221+
32193222
assert start.ndim == 0
32203223
assert stop.ndim == 0
32213224
assert step.ndim == 0
32223225

3226+
# if it is possible to directly determine the shape i.e static shape is present, we find it.
3227+
if (
3228+
isinstance(start, TensorConstant)
3229+
and isinstance(stop, TensorConstant)
3230+
and isinstance(step, TensorConstant)
3231+
):
3232+
length = max(
3233+
ceil((float(stop.data) - float(start.data)) / float(step.data)), 0
3234+
)
3235+
shape = (length,)
3236+
else:
3237+
shape = (None,)
3238+
32233239
inputs = [start, stop, step]
3224-
outputs = [tensor(dtype=self.dtype, shape=(None,))]
3240+
outputs = [tensor(dtype=self.dtype, shape=shape)]
32253241

32263242
return Apply(self, inputs, outputs)
32273243

Diff for: tests/tensor/test_basic.py

+7
Original file line numberDiff line numberDiff line change
@@ -2861,6 +2861,13 @@ def test_infer_shape(self, cast_policy):
28612861
assert np.all(f(2) == len(np.arange(0, 2)))
28622862
assert np.all(f(0) == len(np.arange(0, 0)))
28632863

2864+
def test_static_shape(self):
2865+
assert np.arange(1, 10).shape == arange(1, 10).type.shape
2866+
assert np.arange(10, 1, -1).shape == arange(10, 1, -1).type.shape
2867+
assert np.arange(1, -9, 2).shape == arange(1, -9, 2).type.shape
2868+
assert np.arange(1.3, 17.48, 2.67).shape == arange(1.3, 17.48, 2.67).type.shape
2869+
assert np.arange(-64, 64).shape == arange(-64, 64).type.shape
2870+
28642871

28652872
class TestNdGrid:
28662873
def setup_method(self):

0 commit comments

Comments
 (0)