@@ -5,7 +5,6 @@
 import numpy as np
 
 import pytensor
-import pytensor.scalar.basic as ps
 from pytensor import compile
 from pytensor.compile import optdb
 from pytensor.graph.basic import Constant, Variable
@@ -14,8 +13,11 @@
     copy_stack_trace,
     in2out,
     node_rewriter,
+    out2in,
 )
 from pytensor.raise_op import Assert
+from pytensor.scalar import Add, ScalarConstant, ScalarType
+from pytensor.scalar import constant as scalar_constant
 from pytensor.tensor.basic import (
     Alloc,
     Join,
@@ -31,6 +33,7 @@
     register_infer_shape,
     switch,
 )
+from pytensor.tensor.basic import constant as tensor_constant
 from pytensor.tensor.blockwise import Blockwise
 from pytensor.tensor.elemwise import Elemwise
 from pytensor.tensor.exceptions import NotScalarConstantError
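Note on the import changes above: the `ps` module alias is dropped in favor of importing `Add`, `ScalarConstant`, and `ScalarType` directly, and because both `pytensor.scalar` and `pytensor.tensor.basic` export a function named `constant`, the two are renamed apart as `scalar_constant` and `tensor_constant`. A minimal sketch of the distinction the aliases preserve (my own example, not part of the diff):

```python
import numpy as np

from pytensor.scalar import constant as scalar_constant
from pytensor.tensor.basic import constant as tensor_constant

# Same value, two different graph objects: one lives in the scalar
# sub-language, the other is a 0-d tensor.
s = scalar_constant(np.int64(3))
t = tensor_constant(np.int64(3))

print(type(s).__name__)  # ScalarConstant
print(type(t).__name__)  # TensorConstant
```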
@@ -588,11 +591,11 @@ def local_subtensor_remove_broadcastable_index(fgraph, node):
     remove_dim = []
     node_inputs_idx = 1
     for dim, elem in enumerate(idx):
-        if isinstance(elem, (ps.ScalarType)):
+        if isinstance(elem, ScalarType):
             # The idx is a ScalarType, ie a Type. This means the actual index
             # is contained in node.inputs[1]
             dim_index = node.inputs[node_inputs_idx]
-            if isinstance(dim_index, ps.ScalarConstant):
+            if isinstance(dim_index, ScalarConstant):
                 dim_index = dim_index.value
             if dim_index in (0, -1) and node.inputs[0].broadcastable[dim]:
                 remove_dim.append(dim)
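Context for the `ScalarType` checks above: a `Subtensor` op stores index *types* in `op.idx_list`, while the concrete index variables travel as extra node inputs, which is why the code looks up `node.inputs[node_inputs_idx]` after matching the type. A quick illustration (the example is mine; `idx_list` and `ScalarType` are the real names used in the hunk):

```python
import pytensor.tensor as pt
from pytensor.scalar import ScalarType

x = pt.matrix("x")
i = pt.iscalar("i")
node = x[i].owner  # a Subtensor node

print(node.op.idx_list)  # holds a ScalarType entry, not the index value
print(node.inputs[1])    # the actual index variable rides along as an input
print(isinstance(node.op.idx_list[0], ScalarType))  # True
```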
@@ -770,7 +773,7 @@ def local_subtensor_make_vector(fgraph, node):
 
         (idx,) = idxs
 
-        if isinstance(idx, ps.ScalarType | TensorType):
+        if isinstance(idx, ScalarType | TensorType):
             old_idx, idx = idx, node.inputs[1]
             assert idx.type.is_super(old_idx)
     elif isinstance(node.op, AdvancedSubtensor1):
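For reference, `local_subtensor_make_vector` (the rewrite this hunk touches) lifts indexing through a `MakeVector`, picking an element out of a freshly built vector without materializing it. A hedged sketch of the pattern it targets (my example, not from the diff):

```python
import pytensor.tensor as pt

a, b, c = pt.scalars("a", "b", "c")
v = pt.stack([a, b, c])  # builds a MakeVector node

out = v[1]  # the rewrite reduces this to just `b` during compilation
```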
@@ -895,7 +898,7 @@ def local_set_to_inc_subtensor(fgraph, node):
         and node.op.set_instead_of_inc
         and node.inputs[1].owner
         and isinstance(node.inputs[1].owner.op, Elemwise)
-        and isinstance(node.inputs[1].owner.op.scalar_op, ps.Add)
+        and isinstance(node.inputs[1].owner.op.scalar_op, Add)
     ):
         addn = node.inputs[1].owner
         subn = None
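For context, `local_set_to_inc_subtensor` detects the `set_subtensor(x[idx], x[idx] + y)` pattern, which the `Add` check above matches, and replaces it with the cheaper `inc_subtensor(x[idx], y)`. A sketch of a graph that triggers it (my example, not from the diff):

```python
import pytensor.tensor as pt

x = pt.vector("x")
y = pt.vector("y")

# The value being set is `x[2:4] + y`: an Elemwise whose scalar_op is
# Add, applied to the very slice being overwritten.
out = pt.set_subtensor(x[2:4], x[2:4] + y)
# After the rewrite, this behaves like: pt.inc_subtensor(x[2:4], y)
```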
@@ -1789,7 +1792,6 @@ def local_join_subtensors(fgraph, node):
     return [merged_subtensors]
 
 
-@register_specialize
 @node_rewriter(
     [
         Subtensor,
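The `@register_specialize` decorator removed here belonged to `local_uint_constant_indices`; rather than running unconditionally in the specialize pass, the rewrite is re-registered explicitly on the optdb with backend tags in the final hunk below.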
@@ -1850,12 +1852,10 @@ def local_uint_constant_indices(fgraph, node):
         if dtype == index_val.dtype:
             continue
 
-        if index_val.ndim > 0:
-            new_index = pytensor.tensor.as_tensor_variable(
-                index_val.astype(dtype), dtype=dtype
-            )
+        if isinstance(index.type, TensorType):
+            new_index = tensor_constant(index_val.astype(dtype), dtype=dtype)
         else:
-            new_index = ps.constant(index_val.astype(dtype), dtype=dtype)
+            new_index = scalar_constant(index_val.astype(dtype), dtype=dtype)
 
         new_indices[i] = new_index
         has_new_index = True
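The reworked branch above keys on the index variable's type rather than on `index_val.ndim`, which matters for 0-d tensor indices: their constant value has `ndim == 0`, yet they must be rebuilt as tensor constants, not scalar constants. A small illustration (mine, not from the diff):

```python
import numpy as np
import pytensor.tensor as pt
from pytensor.tensor.type import TensorType

idx = pt.constant(np.array(1, dtype="int64"))  # a 0-d TensorConstant

print(idx.ndim)                          # 0 -- an ndim check would take the scalar branch
print(isinstance(idx.type, TensorType))  # True -- the type check keeps it a tensor
```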
@@ -1877,6 +1877,20 @@ def local_uint_constant_indices(fgraph, node):
     return [new_out]
 
 
+compile.optdb.register(
+    local_uint_constant_indices.__name__,
+    out2in(local_uint_constant_indices),
+    # The Python / C backends always cast indices to int64 internally,
+    # so this rewrite only benefits the numba and jax backends.
+    "numba",
+    "jax",
+    # Run after specialization and uncanonicalization: other rewrites
+    # don't preserve index dtypes and can trigger unnecessary extra
+    # passes of this rewrite, e.g. x.shape[np.int64(0)] -> x.shape[np.uint64(0)].
+    position=4,
+)
+
+
 @register_canonicalize("shape_unsafe")
 @register_stabilize("shape_unsafe")
 @register_specialize("shape_unsafe")
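The explicit `optdb` registration above replaces the `@register_specialize` decorator removed earlier: the rewrite now runs only for backends that request the `numba` or `jax` tags, and `position=4` places it after the specialize and uncanonicalize passes. `out2in(...)` wraps the node rewriter in a graph walker that visits nodes from outputs to inputs, the mirror of the `in2out` walker already imported at the top of the file. A quick check of what the wrapper produces (illustrative; assumes this diff is applied):

```python
from pytensor.graph.rewriting.basic import out2in
from pytensor.tensor.rewriting.subtensor import local_uint_constant_indices

walker = out2in(local_uint_constant_indices)
print(type(walker).__name__)  # e.g. WalkingGraphRewriter
```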