1
+ import warnings
1
2
from collections .abc import Sequence
2
3
from copy import copy
3
4
from textwrap import dedent
19
20
from pytensor .misc .frozendict import frozendict
20
21
from pytensor .printing import Printer , pprint
21
22
from pytensor .scalar import get_scalar_type
23
+ from pytensor .scalar .basic import Composite , transfer_type , upcast
22
24
from pytensor .scalar .basic import bool as scalar_bool
23
25
from pytensor .scalar .basic import identity as scalar_identity
24
- from pytensor .scalar .basic import transfer_type , upcast
25
26
from pytensor .tensor import elemwise_cgen as cgen
26
27
from pytensor .tensor import get_vector_length
27
28
from pytensor .tensor .basic import _get_vector_length , as_tensor_variable
@@ -364,6 +365,7 @@ def __init__(
364
365
self .name = name
365
366
self .scalar_op = scalar_op
366
367
self .inplace_pattern = inplace_pattern
368
+ self .ufunc = None
367
369
self .destroy_map = {o : [i ] for o , i in self .inplace_pattern .items ()}
368
370
369
371
if nfunc_spec is None :
@@ -375,14 +377,12 @@ def __init__(
375
377
def __getstate__ (self ):
376
378
d = copy (self .__dict__ )
377
379
d .pop ("ufunc" )
378
- d .pop ("nfunc" )
379
- d .pop ("__epydoc_asRoutine" , None )
380
380
return d
381
381
382
382
def __setstate__(self, d):
    """Restore pickled state, upgrading entries from older versions."""
    # Older versions stored the `nfunc` callable on the Op; it is no
    # longer an attribute, so silently discard it from legacy pickles.
    d.pop("nfunc", None)
    super().__setstate__(d)
    self.inplace_pattern = frozendict(self.inplace_pattern)
    # The ufunc cache is never serialized; start empty and let it be
    # rebuilt on demand.
    self.ufunc = None
387
387
388
388
def get_output_info (self , * inputs ):
@@ -623,31 +623,49 @@ def transform(r):
623
623
624
624
return ret
625
625
626
- def prepare_node (self , node , storage_map , compute_map , impl ):
627
- # Postpone the ufunc building to the last minutes due to:
628
- # - NumPy ufunc support only up to 32 operands (inputs and outputs)
629
- # But our c code support more.
630
- # - nfunc is reused for scipy and scipy is optional
631
- if (len (node .inputs ) + len (node .outputs )) > 32 and impl == "py" :
632
- impl = "c"
633
-
634
- if getattr (self , "nfunc_spec" , None ) and impl != "c" :
635
- self .nfunc = import_func_from_string (self .nfunc_spec [0 ])
636
-
626
+ def _create_node_ufunc (self , node ) -> None :
637
627
if (
638
- ( len ( node . inputs ) + len ( node . outputs )) <= 32
639
- and ( self . nfunc is None or self . scalar_op . nin != len ( node . inputs ))
640
- and self . ufunc is None
641
- and impl == "py"
628
+ self . nfunc_spec is not None
629
+ # Some scalar Ops like `Add` allow for a variable number of inputs,
630
+ # whereas the numpy counterpart does not.
631
+ and len ( node . inputs ) == self . nfunc_spec [ 1 ]
642
632
):
633
+ # Do we really need to cache this import in the Op?
634
+ # If it's so costly, just memorize `import_func_from_string`
635
+ ufunc = import_func_from_string (self .nfunc_spec [0 ])
636
+ if ufunc is None :
637
+ raise ValueError (
638
+ f"Could not import ufunc { self .nfunc_spec [0 ]} for { self } "
639
+ )
640
+
641
+ elif self .ufunc is not None :
642
+ # Cached before
643
+ ufunc = self .ufunc
644
+
645
+ else :
646
+ if (len (node .inputs ) + len (node .outputs )) > 32 :
647
+ if isinstance (self .scalar_op , Composite ):
648
+ warnings .warn (
649
+ "Trying to create a Python Composite Elemwise function with more than 32 operands.\n "
650
+ "This operation should not have been introduced if the C-backend is not properly setup. "
651
+ 'Make sure it is, or disable it by setting pytensor.config.cxx = "" (empty string).\n '
652
+ "Alternatively, consider using an optional backend like NUMBA or JAX, by setting "
653
+ '`pytensor.config.mode = "NUMBA" (or "JAX").'
654
+ )
655
+ else :
656
+ warnings .warn (
657
+ f"Trying to create a Python Elemwise function for the scalar Op { self .scalar_op } "
658
+ f"with more than 32 operands. This will likely fail."
659
+ )
660
+
643
661
ufunc = np .frompyfunc (
644
662
self .scalar_op .impl , len (node .inputs ), self .scalar_op .nout
645
663
)
646
- if self .scalar_op .nin > 0 :
647
- # We can reuse it for many nodes
664
+ if self .scalar_op .nin > 0 : # Default in base class is -1
665
+ # Op has constant signature, so we can reuse ufunc for many nodes. Cache it.
648
666
self .ufunc = ufunc
649
- else :
650
- node .tag .ufunc = ufunc
667
+
668
+ node .tag .ufunc = ufunc
651
669
652
670
# Numpy ufuncs will sometimes perform operations in
653
671
# float16, in particular when the input is int8.
@@ -669,6 +687,11 @@ def prepare_node(self, node, storage_map, compute_map, impl):
669
687
char = np .sctype2char (out_dtype )
670
688
sig = char * node .nin + "->" + char * node .nout
671
689
node .tag .sig = sig
690
+
691
+ def prepare_node (self , node , storage_map , compute_map , impl ):
692
+ if impl == "py" :
693
+ self ._create_node_ufunc (node )
694
+
672
695
node .tag .fake_node = Apply (
673
696
self .scalar_op ,
674
697
[
@@ -684,71 +707,36 @@ def prepare_node(self, node, storage_map, compute_map, impl):
684
707
self .scalar_op .prepare_node (node .tag .fake_node , None , None , impl )
685
708
686
709
def perform(self, node, inputs, output_storage):
    """Evaluate the elementwise operation in Python via a numpy ufunc.

    Standard `Op.perform` signature: `node` is the `Apply` node being
    evaluated, `inputs` the list of input ndarrays, and `output_storage`
    the list of single-element cells to fill with output ndarrays.
    """
    # The ufunc for this node is built lazily by `_create_node_ufunc` and
    # cached on `node.tag` (and, when the scalar Op has a fixed signature,
    # on `self`).  It can be absent here because `prepare_node` only
    # builds it for the "py" impl — presumably `perform` may also be
    # called directly without a prior `prepare_node`.
    ufunc = getattr(node.tag, "ufunc", None)
    if ufunc is None:
        self._create_node_ufunc(node)
        ufunc = node.tag.ufunc

    self._check_runtime_broadcast(node, inputs)

    ufunc_kwargs = {}
    # `node.tag.sig`, when set, pins an explicit dtype signature so numpy
    # does not compute in an unwanted intermediate precision (the
    # surrounding code notes float16 intermediates for int8 inputs).
    if hasattr(node.tag, "sig"):
        ufunc_kwargs["sig"] = node.tag.sig

    outputs = ufunc(*inputs, **ufunc_kwargs)

    # Multi-output ufuncs return a tuple; normalize the single-output
    # case so the loop below is uniform.
    if not isinstance(outputs, tuple):
        outputs = (outputs,)

    for i, (out, out_storage, node_out) in enumerate(
        zip(outputs, output_storage, node.outputs)
    ):
        # Numpy frompyfunc always returns object arrays, so cast each
        # result to the dtype declared by the output variable.
        out_storage[0] = out = np.asarray(out, dtype=node_out.dtype)

        if i in self.inplace_pattern:
            # This output is declared to destroy an input: copy the
            # result into the input buffer and hand that buffer back, so
            # aliases of the input observe the update.
            inp = inputs[self.inplace_pattern[i]]
            inp[...] = out
            out_storage[0] = inp

        # numpy.real returns a view!  Outputs must own their data, so
        # replace any view with a copy.
        if not out.flags.owndata:
            out_storage[0] = out.copy()
752
740
753
741
@staticmethod
754
742
def _check_runtime_broadcast (node , inputs ):
0 commit comments