Skip to content

Commit 8158627

Browse files
committed
Use vectorize in finite_discrete_marginal_rv_logp
1 parent 63571f0 commit 8158627

File tree

1 file changed

+15
-27
lines changed

1 file changed

+15
-27
lines changed

Diff for: pymc_experimental/model/marginal_model.py

+15-27
Original file line numberDiff line numberDiff line change
@@ -13,13 +13,12 @@
1313
from pymc.logprob.basic import conditional_logp, logp
1414
from pymc.logprob.transforms import IntervalTransform
1515
from pymc.model import Model
16-
from pymc.pytensorf import compile_pymc, constant_fold, inputvars
16+
from pymc.pytensorf import compile_pymc, constant_fold
1717
from pymc.util import _get_seeds_per_chain, treedict
1818
from pytensor import Mode, scan
1919
from pytensor.compile import SharedVariable
20-
from pytensor.compile.builders import OpFromGraph
2120
from pytensor.graph import Constant, FunctionGraph, ancestors, clone_replace
22-
from pytensor.graph.replace import vectorize_graph
21+
from pytensor.graph.replace import graph_replace, vectorize_graph
2322
from pytensor.scan import map as scan_map
2423
from pytensor.tensor import TensorType, TensorVariable
2524
from pytensor.tensor.elemwise import Elemwise
@@ -686,31 +685,23 @@ def _add_reduce_batch_dependent_logps(
686685
def finite_discrete_marginal_rv_logp(op, values, *inputs, **kwargs):
687686
# Clone the inner RV graph of the Marginalized RV
688687
marginalized_rvs_node = op.make_node(*inputs)
689-
inner_rvs = clone_replace(
688+
marginalized_rv, *inner_rvs = clone_replace(
690689
op.inner_outputs,
691690
replace={u: v for u, v in zip(op.inner_inputs, marginalized_rvs_node.inputs)},
692691
)
693-
marginalized_rv = inner_rvs[0]
694692

695693
# Obtain the joint_logp graph of the inner RV graph
696-
inner_rvs_to_values = {rv: rv.clone() for rv in inner_rvs}
697-
logps_dict = conditional_logp(rv_values=inner_rvs_to_values, **kwargs)
694+
inner_rv_values = dict(zip(inner_rvs, values))
695+
marginalized_vv = marginalized_rv.clone()
696+
rv_values = inner_rv_values | {marginalized_rv: marginalized_vv}
697+
logps_dict = conditional_logp(rv_values=rv_values, **kwargs)
698698

699699
# Reduce logp dimensions corresponding to broadcasted variables
700-
marginalized_logp = logps_dict.pop(inner_rvs_to_values[marginalized_rv])
700+
marginalized_logp = logps_dict.pop(marginalized_vv)
701701
joint_logp = marginalized_logp + _add_reduce_batch_dependent_logps(
702702
marginalized_rv.type, logps_dict.values()
703703
)
704704

705-
# Wrap the joint_logp graph in an OpFromGraph, so that we can evaluate it at different
706-
# values of the marginalized RV
707-
# Some inputs are not root inputs (such as transformed projections of value variables)
708-
# Or cannot be used as inputs to an OpFromGraph (shared variables and constants)
709-
inputs = list(inputvars(inputs))
710-
joint_logp_op = OpFromGraph(
711-
list(inner_rvs_to_values.values()) + inputs, [joint_logp], inline=True
712-
)
713-
714705
# Compute the joint_logp for all possible n values of the marginalized RV. We assume
715706
# each original dimension is independent so that it suffices to evaluate the graph
716707
# n times, once with each possible value of the marginalized RV replicated across
@@ -729,17 +720,14 @@ def finite_discrete_marginal_rv_logp(op, values, *inputs, **kwargs):
729720
0,
730721
)
731722

732-
# Arbitrary cutoff to switch to Scan implementation to keep graph size under control
733-
# TODO: Try vectorize here
734-
if len(marginalized_rv_domain) <= 10:
735-
joint_logps = [
736-
joint_logp_op(marginalized_rv_domain_tensor[i], *values, *inputs)
737-
for i in range(len(marginalized_rv_domain))
738-
]
739-
else:
740-
723+
try:
724+
joint_logps = vectorize_graph(
725+
joint_logp, replace={marginalized_vv: marginalized_rv_domain_tensor}
726+
)
727+
except Exception:
728+
# Fallback to Scan
741729
def logp_fn(marginalized_rv_const, *non_sequences):
742-
return joint_logp_op(marginalized_rv_const, *non_sequences)
730+
return graph_replace(joint_logp, replace={marginalized_vv: marginalized_rv_const})
743731

744732
joint_logps, _ = scan_map(
745733
fn=logp_fn,

0 commit comments

Comments
 (0)