 from pymc.logprob.basic import conditional_logp, logp
 from pymc.logprob.transforms import IntervalTransform
 from pymc.model import Model
-from pymc.pytensorf import compile_pymc, constant_fold, inputvars
+from pymc.pytensorf import compile_pymc, constant_fold
 from pymc.util import _get_seeds_per_chain, treedict
 from pytensor import Mode, scan
 from pytensor.compile import SharedVariable
-from pytensor.compile.builders import OpFromGraph
 from pytensor.graph import Constant, FunctionGraph, ancestors, clone_replace
-from pytensor.graph.replace import vectorize_graph
+from pytensor.graph.replace import graph_replace, vectorize_graph
 from pytensor.scan import map as scan_map
 from pytensor.tensor import TensorType, TensorVariable
 from pytensor.tensor.elemwise import Elemwise
@@ -686,31 +685,23 @@ def _add_reduce_batch_dependent_logps(
 def finite_discrete_marginal_rv_logp(op, values, *inputs, **kwargs):
     # Clone the inner RV graph of the Marginalized RV
     marginalized_rvs_node = op.make_node(*inputs)
-    inner_rvs = clone_replace(
+    marginalized_rv, *inner_rvs = clone_replace(
         op.inner_outputs,
         replace={u: v for u, v in zip(op.inner_inputs, marginalized_rvs_node.inputs)},
     )
-    marginalized_rv = inner_rvs[0]

     # Obtain the joint_logp graph of the inner RV graph
-    inner_rvs_to_values = {rv: rv.clone() for rv in inner_rvs}
-    logps_dict = conditional_logp(rv_values=inner_rvs_to_values, **kwargs)
+    inner_rv_values = dict(zip(inner_rvs, values))
+    marginalized_vv = marginalized_rv.clone()
+    rv_values = inner_rv_values | {marginalized_rv: marginalized_vv}
+    logps_dict = conditional_logp(rv_values=rv_values, **kwargs)

     # Reduce logp dimensions corresponding to broadcasted variables
-    marginalized_logp = logps_dict.pop(inner_rvs_to_values[marginalized_rv])
+    marginalized_logp = logps_dict.pop(marginalized_vv)
     joint_logp = marginalized_logp + _add_reduce_batch_dependent_logps(
         marginalized_rv.type, logps_dict.values()
     )

-    # Wrap the joint_logp graph in an OpFromGraph, so that we can evaluate it at different
-    # values of the marginalized RV
-    # Some inputs are not root inputs (such as transformed projections of value variables)
-    # Or cannot be used as inputs to an OpFromGraph (shared variables and constants)
-    inputs = list(inputvars(inputs))
-    joint_logp_op = OpFromGraph(
-        list(inner_rvs_to_values.values()) + inputs, [joint_logp], inline=True
-    )
-
     # Compute the joint_logp for all possible n values of the marginalized RV. We assume
     # each original dimension is independent so that it suffices to evaluate the graph
     # n times, once with each possible value of the marginalized RV replicated across
@@ -729,17 +720,14 @@ def finite_discrete_marginal_rv_logp(op, values, *inputs, **kwargs):
         0,
     )

-    # Arbitrary cutoff to switch to Scan implementation to keep graph size under control
-    # TODO: Try vectorize here
-    if len(marginalized_rv_domain) <= 10:
-        joint_logps = [
-            joint_logp_op(marginalized_rv_domain_tensor[i], *values, *inputs)
-            for i in range(len(marginalized_rv_domain))
-        ]
-    else:
-
+    try:
+        joint_logps = vectorize_graph(
+            joint_logp, replace={marginalized_vv: marginalized_rv_domain_tensor}
+        )
+    except Exception:
+        # Fallback to Scan
         def logp_fn(marginalized_rv_const, *non_sequences):
-            return joint_logp_op(marginalized_rv_const, *non_sequences)
+            return graph_replace(joint_logp, replace={marginalized_vv: marginalized_rv_const})

         joint_logps, _ = scan_map(
             fn=logp_fn,
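For context on the PyTensor API the new branch relies on, here is a minimal standalone sketch (not part of this diff; the toy log-density and the names x, mu, and domain are purely illustrative) of how vectorize_graph rebuilds a graph written for a single value so that it evaluates over a batch of candidate values, analogous to what the new code does with joint_logp across the marginalized RV's domain:

# Illustrative sketch only (not the PR's code): vectorize_graph batches a graph
# built around a single scalar variable over a vector of candidate values.
import numpy as np
import pytensor
import pytensor.tensor as pt
from pytensor.graph.replace import vectorize_graph

x = pt.dscalar("x")                # stands in for the marginalized value variable
mu = pt.dscalar("mu")
toy_logp = -0.5 * (x - mu) ** 2    # toy log-density standing in for joint_logp

domain = pt.dvector("domain")      # all candidate values of the marginalized RV
batched_logp = vectorize_graph(toy_logp, replace={x: domain})

fn = pytensor.function([domain, mu], batched_logp)
print(fn(np.array([0.0, 1.0, 2.0]), 1.0))  # one logp per candidate value

When vectorization is not supported for some Op in the graph, the diff falls back to scan_map, using graph_replace to substitute one candidate value at a time and letting Scan accumulate the per-value logps.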