@@ -709,29 +709,66 @@ def test_shape_tuple():
709
709
710
710
711
711
class TestVectorize :
712
+ @pytensor .config .change_flags (cxx = "" ) # For faster eval
712
713
def test_shape (self ):
713
- vec = tensor (shape = (None ,))
714
- mat = tensor (shape = (None , None ))
715
-
714
+ vec = tensor (shape = (None ,), dtype = "float64" )
715
+ mat = tensor (shape = (None , None ), dtype = "float64" )
716
716
node = shape (vec ).owner
717
- vect_node = vectorize_node (node , mat )
718
- assert equal_computations (vect_node .outputs , [shape (mat )])
719
717
718
+ [vect_out ] = vectorize_node (node , mat ).outputs
719
+ assert equal_computations (
720
+ [vect_out ], [broadcast_to (mat .shape [1 :], (* mat .shape [:1 ], 1 ))]
721
+ )
722
+
723
+ mat_test_value = np .ones ((5 , 3 ))
724
+ ref_fn = np .vectorize (lambda vec : np .asarray (vec .shape ), signature = "(vec)->(1)" )
725
+ np .testing .assert_array_equal (
726
+ vect_out .eval ({mat : mat_test_value }),
727
+ ref_fn (mat_test_value ),
728
+ )
729
+
730
+ mat = tensor (shape = (None , None ), dtype = "float64" )
731
+ tns = tensor (shape = (None , None , None , None ), dtype = "float64" )
732
+ node = shape (mat ).owner
733
+ [vect_out ] = vectorize_node (node , tns ).outputs
734
+ assert equal_computations (
735
+ [vect_out ], [broadcast_to (tns .shape [2 :], (* tns .shape [:2 ], 2 ))]
736
+ )
737
+
738
+ tns_test_value = np .ones ((4 , 6 , 5 , 3 ))
739
+ ref_fn = np .vectorize (
740
+ lambda vec : np .asarray (vec .shape ), signature = "(m1,m2)->(2)"
741
+ )
742
+ np .testing .assert_array_equal (
743
+ vect_out .eval ({tns : tns_test_value }),
744
+ ref_fn (tns_test_value ),
745
+ )
746
+
747
+ @pytensor .config .change_flags (cxx = "" ) # For faster eval
720
748
def test_reshape (self ):
721
749
x = scalar ("x" , dtype = int )
722
- vec = tensor (shape = (None ,))
723
- mat = tensor (shape = (None , None ))
750
+ vec = tensor (shape = (None ,), dtype = "float64" )
751
+ mat = tensor (shape = (None , None ), dtype = "float64" )
724
752
725
- shape = (2 , x )
753
+ shape = (- 1 , x )
726
754
node = reshape (vec , shape ).owner
727
- vect_node = vectorize_node (node , mat , shape )
728
- assert equal_computations (
729
- vect_node .outputs , [reshape (mat , (* mat .shape [:1 ], 2 , x ))]
755
+
756
+ [vect_out ] = vectorize_node (node , mat , shape ).outputs
757
+ assert equal_computations ([vect_out ], [reshape (mat , (* mat .shape [:1 ], - 1 , x ))])
758
+
759
+ x_test_value = 2
760
+ mat_test_value = np .ones ((5 , 6 ))
761
+ ref_fn = np .vectorize (
762
+ lambda x , vec : vec .reshape (- 1 , x ), signature = "(),(vec1)->(mat1,mat2)"
763
+ )
764
+ np .testing .assert_array_equal (
765
+ vect_out .eval ({x : x_test_value , mat : mat_test_value }),
766
+ ref_fn (x_test_value , mat_test_value ),
730
767
)
731
768
732
- new_shape = (5 , 2 , x )
733
- vect_node = vectorize_node (node , mat , new_shape )
734
- assert equal_computations (vect_node . outputs , [reshape (mat , new_shape )])
769
+ new_shape = (5 , - 1 , x )
770
+ [ vect_out ] = vectorize_node (node , mat , new_shape ). outputs
771
+ assert equal_computations ([ vect_out ] , [reshape (mat , new_shape )])
735
772
736
773
with pytest .raises (NotImplementedError ):
737
774
vectorize_node (node , vec , broadcast_to (as_tensor ([5 , 2 , x ]), (2 , 3 )))
0 commit comments