diff --git a/docs/api_reference.rst b/docs/api_reference.rst
index f2219e08..5d4f62bb 100644
--- a/docs/api_reference.rst
+++ b/docs/api_reference.rst
@@ -47,6 +47,16 @@ Distributions
    histogram_approximation
 
 
+Transforms
+==========
+
+.. currentmodule:: pymc_extras.distributions.transforms
+.. autosummary::
+   :toctree: generated/
+
+   PartialOrder
+
+
 Utils
 =====
 
diff --git a/pymc_extras/distributions/__init__.py b/pymc_extras/distributions/__init__.py
index 046e3b60..783154ba 100644
--- a/pymc_extras/distributions/__init__.py
+++ b/pymc_extras/distributions/__init__.py
@@ -26,6 +26,7 @@
 from pymc_extras.distributions.histogram_utils import histogram_approximation
 from pymc_extras.distributions.multivariate import R2D2M2CP
 from pymc_extras.distributions.timeseries import DiscreteMarkovChain
+from pymc_extras.distributions.transforms import PartialOrder
 
 __all__ = [
     "Chi",
@@ -37,4 +38,5 @@
     "R2D2M2CP",
     "Skellam",
     "histogram_approximation",
+    "PartialOrder",
 ]
diff --git a/pymc_extras/distributions/transforms/__init__.py b/pymc_extras/distributions/transforms/__init__.py
new file mode 100644
index 00000000..cb02d0f5
--- /dev/null
+++ b/pymc_extras/distributions/transforms/__init__.py
@@ -0,0 +1,3 @@
+from pymc_extras.distributions.transforms.partial_order import PartialOrder
+
+__all__ = ["PartialOrder"]
diff --git a/pymc_extras/distributions/transforms/partial_order.py b/pymc_extras/distributions/transforms/partial_order.py
new file mode 100644
index 00000000..d63c51ee
--- /dev/null
+++ b/pymc_extras/distributions/transforms/partial_order.py
@@ -0,0 +1,227 @@
+# Copyright 2025 The PyMC Developers
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import numpy as np
+import pytensor.tensor as pt
+
+from pymc.logprob.transforms import Transform
+
+__all__ = ["PartialOrder"]
+
+
+def dtype_minval(dtype):
+    """Find the minimum value for a given dtype"""
+    return np.iinfo(dtype).min if np.issubdtype(dtype, np.integer) else np.finfo(dtype).min
+
+
+def padded_where(x, to_len, padval=-1):
+    """A padded version of np.where"""
+    w = np.where(x)
+    return np.concatenate([w[0], np.full(to_len - len(w[0]), padval)])
+
+
+class PartialOrder(Transform):
+    """Create a PartialOrder transform
+
+    A more flexible version of the pymc ordered transform that
+    allows specifying a (strict) partial order on the elements.
+
+    Examples
+    --------
+    .. code:: python
+
+        import numpy as np
+        import pymc as pm
+        import pymc_extras as pmx
+
+        # Define two partial orders on 4 elements
+        # am[i,j] = 1 means i < j
+        adj_mats = np.array([
+            # 0 < {1, 2} < 3
+            [[0, 1, 1, 0],
+             [0, 0, 0, 1],
+             [0, 0, 0, 1],
+             [0, 0, 0, 0]],
+
+            # 1 < 0 < 3 < 2
+            [[0, 0, 0, 1],
+             [1, 0, 0, 0],
+             [0, 0, 0, 0],
+             [0, 0, 1, 0]],
+        ])
+
+        # Create the partial order from the adjacency matrices
+        po = pmx.PartialOrder(adj_mats)
+
+        with pm.Model() as model:
+            # Generate 3 samples from both partial orders
+            pm.Normal("po_vals", shape=(3,2,4), transform=po,
+                      initval=po.initvals((3,2,4)))
+
+            idata = pm.sample()
+
+        # Verify that for the first po, the zeroth element is always the smallest
+        assert (idata.posterior['po_vals'][:,:,:,0,0] <
+                idata.posterior['po_vals'][:,:,:,0,1:]).all()
+
+        # Verify that for the second po, the second element is always the largest
+        assert (idata.posterior['po_vals'][:,:,:,1,2] >=
+                idata.posterior['po_vals'][:,:,:,1,:]).all()
+
+    Technical notes
+    ---------------
+    The partial order must be strict, i.e. without equalities.
+    A DAG defining the partial order is sufficient, as the transitive closure is computed automatically.
+    The transform runs in O(N*D) time, but initialization takes O(N^3),
+    where N is the number of nodes in the DAG and D is the maximum
+    in-degree of a node in the transitive reduction.
+    """
+
+    name = "partial_order"
+
+    def __init__(self, adj_mat):
+        """
+        Initialize the PartialOrder transform
+
+        Parameters
+        ----------
+        adj_mat: ndarray
+            adjacency matrix for the DAG that generates the partial order,
+            where ``adj_mat[i][j] = 1`` denotes ``i < j``.
+            Note that this also accepts multiple DAGs if the RV is multidimensional.
+        """
+
+        # Basic input checks
+        if adj_mat.ndim < 2:
+            raise ValueError("Adjacency matrix must have at least 2 dimensions")
+        if adj_mat.shape[-2] != adj_mat.shape[-1]:
+            raise ValueError("Adjacency matrix is not square")
+        if adj_mat.min() != 0 or adj_mat.max() != 1:
+            raise ValueError("Adjacency matrix must contain only 0s and 1s")
+
+        # Create index over the first ellipsis dimensions
+        idx = np.ix_(*[np.arange(s) for s in adj_mat.shape[:-2]])
+
+        # Transitive closure using Floyd-Warshall
+        tc = adj_mat.astype(bool)
+        for k in range(tc.shape[-1]):
+            tc |= np.logical_and(tc[..., :, k, None], tc[..., None, k, :])
+
+        # Check that the DAG is acyclic
+        if np.any(tc.diagonal(axis1=-2, axis2=-1)):
+            raise ValueError("Partial order contains equalities")
+
+        # Transitive reduction using the closure
+        # This gives the minimum description of the partial order
+        # and minimizes the maximum in-degree
+        adj_mat = tc * (1 - np.matmul(tc, tc))
+
+        # Find the maximum in-degree of the reduced dag
+        dag_idim = adj_mat.sum(axis=-2).max()
+
+        # Topological sort
+        ts_inds = np.zeros(adj_mat.shape[:-1], dtype=int)
+        dm = adj_mat.copy()
+        for i in range(adj_mat.shape[-1]):
+            assert dm.sum(axis=-2).min() == 0  # DAG is acyclic
+            nind = np.argmin(dm.sum(axis=-2), axis=-1)
+            dm[(*idx, slice(None), nind)] = 1  # Make nind not show up again
+            dm[(*idx, nind, slice(None))] = 0  # Allow its children to show up
+            ts_inds[(*idx, i)] = nind
+        self.ts_inds = ts_inds
+
+        # Change the dag to adjacency lists (with -1 for NA)
+        dag_T = np.apply_along_axis(padded_where, axis=-2, arr=adj_mat, padval=-1, to_len=dag_idim)
+        self.dag = np.swapaxes(dag_T, -2, -1)
+        self.is_start = np.all(self.dag[..., :, :] == -1, axis=-1)
+
+    def initvals(self, shape=None, lower=-1, upper=1):
+        """
+        Create a set of appropriate initial values for the variable.
+        NB! It is important that proper initial values are used,
+        as only properly ordered values are in the range of the transform.
+
+        Parameters
+        ----------
+        shape: tuple, default None
+            shape of the initial values. If None, ``adj_mat.shape[:-1]`` is used
+        lower: float, default -1
+            lower bound for the initial values
+        upper: float, default 1
+            upper bound for the initial values
+
+        Returns
+        -------
+        vals: ndarray
+            initial values for the transformed variable
+        """
+
+        if shape is None:
+            shape = self.dag.shape[:-1]
+
+        if shape[-len(self.dag.shape[:-1]) :] != self.dag.shape[:-1]:
+            raise ValueError("Shape must match the shape of the adjacency matrix")
+
+        # Create the initial values
+        vals = np.linspace(lower, upper, self.dag.shape[-2])
+        inds = np.argsort(self.ts_inds, axis=-1)
+        ivals = vals[inds]
+
+        # Expand the initial values to the extra dimensions
+        extra_dims = shape[: -len(self.dag.shape[:-1])]
+        ivals = np.tile(ivals, extra_dims + tuple([1] * len(self.dag.shape[:-1])))
+
+        return ivals
+
+    def backward(self, value, *inputs):
+        minv = dtype_minval(value.dtype)
+        x = pt.concatenate(
+            [pt.zeros_like(value), pt.full(value.shape[:-1], minv)[..., None]], axis=-1
+        )
+
+        # Indices to allow broadcasting the max over the last dimension
+        idx = np.ix_(*[np.arange(s) for s in self.dag.shape[:-2]])
+        idx2 = tuple(np.tile(i[:, None], self.dag.shape[-1]) for i in idx)
+
+        # Has to be done stepwise as next steps depend on previous values
+        # Also has to be done in topological order, hence the ts_inds
+        for i in range(self.dag.shape[-2]):
+            tsi = self.ts_inds[..., i]
+            if len(tsi.shape) == 0:
+                tsi = int(tsi)  # if shape 0, it's a scalar
+            ni = (*idx, tsi)  # i-th node in topological order
+            eni = (Ellipsis, *ni)
+            ist = self.is_start[ni]
+
+            mval = pt.max(x[(Ellipsis, *idx2, self.dag[ni])], axis=-1)
+            x = pt.set_subtensor(x[eni], ist * value[eni] + (1 - ist) * (mval + pt.exp(value[eni])))
+        return x[..., :-1]
+
+    def forward(self, value, *inputs):
+        y = pt.zeros_like(value)
+
+        minv = dtype_minval(value.dtype)
+        vx = pt.concatenate([value, pt.full(value.shape[:-1], minv)[..., None]], axis=-1)
+
+        # Indices to allow broadcasting the max over the last dimension
+        idx = np.ix_(*[np.arange(s) for s in self.dag.shape[:-2]])
+        idx = tuple(np.tile(i[:, None, None], self.dag.shape[-2:]) for i in idx)
+
+        y = self.is_start * value + (1 - self.is_start) * (
+            pt.log(value - pt.max(vx[(Ellipsis, *idx, self.dag[..., :])], axis=-1))
+        )
+
+        return y
+
+    def log_jac_det(self, value, *inputs):
+        return pt.sum(value * (1 - self.is_start), axis=-1)
diff --git a/tests/distributions/test_transform.py b/tests/distributions/test_transform.py
new file mode 100644
index 00000000..4638d01a
--- /dev/null
+++ b/tests/distributions/test_transform.py
@@ -0,0 +1,77 @@
+# Copyright 2025 The PyMC Developers
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import numpy as np
+import pymc as pm
+
+from pymc_extras.distributions.transforms import PartialOrder
+
+
+class TestPartialOrder:
+    adj_mats = np.array(
+        [
+            # 0 < {1, 2} < 3
+            [[0, 1, 1, 0], [0, 0, 0, 1], [0, 0, 0, 1], [0, 0, 0, 0]],
+            # 1 < 0 < 3 < 2
+            [[0, 0, 0, 1], [1, 0, 0, 0], [0, 0, 0, 0], [0, 0, 1, 0]],
+        ]
+    )
+
+    valid_values = np.array([[0, 2, 1, 3], [1, 0, 3, 2]], dtype=float)
+
+    # Test that forward and backward are inverses of each other
+    # and that they work when extra dimensions are added to the data
+    def test_forward_backward_dimensionality(self):
+        po = PartialOrder(self.adj_mats)
+        po0 = PartialOrder(self.adj_mats[0])
+        vv = self.valid_values
+        vv0 = self.valid_values[0]
+
+        testsets = [
+            (vv, po),
+            (po.initvals(), po),
+            (vv0, po0),
+            (po0.initvals(), po0),
+            (np.tile(vv0, (2, 1)), po0),
+            (np.tile(vv0, (2, 3, 2, 1)), po0),
+            (np.tile(vv, (2, 3, 2, 1, 1)), po),
+        ]
+
+        for vv, po in testsets:
+            fw = po.forward(vv)
+            bw = po.backward(fw)
+            np.testing.assert_allclose(bw.eval(), vv)
+
+    def test_sample_model(self):
+        po = PartialOrder(self.adj_mats)
+        with pm.Model() as model:
+            x = pm.Normal(
+                "x",
+                size=(3, 2, 4),
+                transform=po,
+                initval=po.initvals(shape=(3, 2, 4), lower=-1, upper=1),
+            )
+            idata = pm.sample()
+
+        # Check that the order constraints are satisfied
+        # Move chain, draw and "3" dimensions to the back
+        xvs = idata.posterior.x.values.transpose(3, 4, 0, 1, 2)
+        x0 = xvs[0]  # 0 < {1, 2} < 3
+        assert (
+            (x0[0] < x0[1]).all()
+            and (x0[0] < x0[2]).all()
+            and (x0[1] < x0[3]).all()
+            and (x0[2] < x0[3]).all()
+        )
+        x1 = xvs[1]  # 1 < 0 < 3 < 2
+        assert (x1[1] < x1[0]).all() and (x1[0] < x1[3]).all() and (x1[3] < x1[2]).all()
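
For reference, the transitive closure and reduction step in ``PartialOrder.__init__`` is easier to follow in scalar form. Below is a minimal standalone NumPy sketch of the same computation for a single 3-node chain DAG; the names ``adj``, ``tc`` and ``red`` are illustrative only and do not appear in the patch.

    import numpy as np

    # Chain 0 < 1 < 2, encoded as adj[i, j] = 1 meaning i < j
    adj = np.array([
        [0, 1, 0],
        [0, 0, 1],
        [0, 0, 0],
    ])

    # Transitive closure via Floyd-Warshall, mirroring PartialOrder.__init__
    tc = adj.astype(bool)
    for k in range(tc.shape[-1]):
        tc |= np.logical_and(tc[..., :, k, None], tc[..., None, k, :])
    assert not tc.diagonal().any()  # no cycles, so the order is strict
    # tc now also contains the implied edge 0 < 2

    # Transitive reduction: drop edges already implied by a longer path
    red = tc * (1 - np.matmul(tc, tc))
    # red keeps only 0 < 1 and 1 < 2, which minimizes the maximum in-degree D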