Skip to content

Commit be4d843

Browse files
authored
Added a PartialOrder transform (#444)
1 parent 4431749 commit be4d843

File tree

5 files changed

+319
-0
lines changed

5 files changed

+319
-0
lines changed

docs/api_reference.rst

+10
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,16 @@ Distributions
4747
histogram_approximation
4848

4949

50+
Transforms
51+
==========
52+
53+
.. currentmodule:: pymc_extras.distributions.transforms
54+
.. autosummary::
55+
:toctree: generated/
56+
57+
PartialOrder
58+
59+
5060
Utils
5161
=====
5262

pymc_extras/distributions/__init__.py

+2
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
from pymc_extras.distributions.histogram_utils import histogram_approximation
2727
from pymc_extras.distributions.multivariate import R2D2M2CP
2828
from pymc_extras.distributions.timeseries import DiscreteMarkovChain
29+
from pymc_extras.distributions.transforms import PartialOrder
2930

3031
__all__ = [
3132
"Chi",
@@ -37,4 +38,5 @@
3738
"R2D2M2CP",
3839
"Skellam",
3940
"histogram_approximation",
41+
"PartialOrder",
4042
]
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
from pymc_extras.distributions.transforms.partial_order import PartialOrder
2+
3+
__all__ = ["PartialOrder"]
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,227 @@
1+
# Copyright 2025 The PyMC Developers
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
import numpy as np
15+
import pytensor.tensor as pt
16+
17+
from pymc.logprob.transforms import Transform
18+
19+
__all__ = ["PartialOrder"]
20+
21+
22+
def dtype_minval(dtype):
    """Return the smallest value representable by ``dtype``.

    Integer dtypes are looked up through ``np.iinfo``; any other dtype is
    handed to ``np.finfo``.
    """
    if np.issubdtype(dtype, np.integer):
        return np.iinfo(dtype).min
    return np.finfo(dtype).min
25+
26+
27+
def padded_where(x, to_len, padval=-1):
    """Return the first-axis indices of the nonzero entries of ``x``, padded to a fixed length.

    The indices come from ``np.where(x)[0]``; ``padval`` is appended until
    the result has ``to_len`` entries.
    """
    hits = np.where(x)[0]
    filler = np.full(to_len - len(hits), padval)
    return np.concatenate([hits, filler])
31+
32+
33+
class PartialOrder(Transform):
    """Create a PartialOrder transform

    A more flexible version of the pymc ordered transform that
    allows specifying a (strict) partial order on the elements.

    Examples
    --------
    .. code:: python

        import numpy as np
        import pymc as pm
        import pymc_extras as pmx

        # Define two partial orders on 4 elements
        # am[i,j] = 1 means i < j
        adj_mats = np.array([
            # 0 < {1, 2} < 3
            [[0, 1, 1, 0],
             [0, 0, 0, 1],
             [0, 0, 0, 1],
             [0, 0, 0, 0]],

            # 1 < 0 < 3 < 2
            [[0, 0, 0, 1],
             [1, 0, 0, 0],
             [0, 0, 0, 0],
             [0, 0, 1, 0]],
        ])

        # Create the partial order from the adjacency matrices
        po = pmx.PartialOrder(adj_mats)

        with pm.Model() as model:
            # Generate 3 samples from both partial orders
            pm.Normal("po_vals", shape=(3,2,4), transform=po,
                      initval=po.initvals((3,2,4)))

            idata = pm.sample()

        # Verify that for first po, the zeroth element is always the smallest
        assert (idata.posterior['po_vals'][:,:,:,0,0] <
                idata.posterior['po_vals'][:,:,:,0,1:]).all()

        # Verify that for second po, the second element is always the largest
        assert (idata.posterior['po_vals'][:,:,:,1,2] >=
                idata.posterior['po_vals'][:,:,:,1,:]).all()

    Technical notes
    ---------------
    Partial order needs to be strict, i.e. without equalities.
    A DAG defining the partial order is sufficient, as transitive closure is automatically computed.
    Code works in O(N*D) in runtime, but takes O(N^3) in initialization,
    where N is the number of nodes in the dag and D is the maximum
    in-degree of a node in the transitive reduction.
    """

    name = "partial_order"

    def __init__(self, adj_mat):
        """
        Initialize the PartialOrder transform

        Parameters
        ----------
        adj_mat: ndarray
            adjacency matrix for the DAG that generates the partial order,
            where ``adj_mat[i][j] = 1`` denotes ``i < j``.
            Note this also accepts multiple DAGs if RV is multidimensional:
            any leading dimensions index independent DAGs.

        Raises
        ------
        ValueError
            If ``adj_mat`` has fewer than 2 dimensions, its last two
            dimensions differ, it contains anything other than 0s and 1s
            (at least one of each is required), or the relation it
            describes contains a cycle.
        """

        # Accept nested sequences as well as ndarrays
        adj_mat = np.asarray(adj_mat)

        # Basic input checks
        if adj_mat.ndim < 2:
            raise ValueError("Adjacency matrix must have at least 2 dimensions")
        if adj_mat.shape[-2] != adj_mat.shape[-1]:
            raise ValueError("Adjacency matrix is not square")
        # min/max alone would let values such as 0.5 through, so also check
        # every entry individually
        only_binary = np.all((adj_mat == 0) | (adj_mat == 1))
        if not only_binary or adj_mat.min() != 0 or adj_mat.max() != 1:
            raise ValueError("Adjacency matrix must contain only 0s and 1s")

        # Create index over the first ellipsis dimensions
        idx = np.ix_(*[np.arange(s) for s in adj_mat.shape[:-2]])

        # Transitive closure using Floyd-Warshall
        # (astype makes a copy, so the caller's array is not modified)
        tc = adj_mat.astype(bool)
        for k in range(tc.shape[-1]):
            tc |= np.logical_and(tc[..., :, k, None], tc[..., None, k, :])

        # Check if the dag is acyclic
        if np.any(tc.diagonal(axis1=-2, axis2=-1)):
            raise ValueError("Partial order contains equalities")

        # Transitive reduction using the closure
        # This gives the minimum description of the partial order
        # This is to minmax the input degree
        adj_mat = tc * (1 - np.matmul(tc, tc))

        # Find the maximum in-degree of the reduced dag
        dag_idim = adj_mat.sum(axis=-2).max()

        # Topological sort
        ts_inds = np.zeros(adj_mat.shape[:-1], dtype=int)
        dm = adj_mat.copy()
        # Iterate once per node; the node count is the LAST axis, whatever
        # the number of leading batch dimensions
        for i in range(adj_mat.shape[-1]):
            assert dm.sum(axis=-2).min() == 0  # DAG is acyclic
            nind = np.argmin(dm.sum(axis=-2), axis=-1)
            dm[(*idx, slice(None), nind)] = 1  # Make nind not show up again
            dm[(*idx, nind, slice(None))] = 0  # Allow its children to show
            ts_inds[(*idx, i)] = nind
        self.ts_inds = ts_inds

        # Change the dag to adjacency lists (with -1 for NA)
        dag_T = np.apply_along_axis(padded_where, axis=-2, arr=adj_mat, padval=-1, to_len=dag_idim)
        self.dag = np.swapaxes(dag_T, -2, -1)
        # Nodes whose parent list is all padding have no predecessors
        self.is_start = np.all(self.dag[..., :, :] == -1, axis=-1)

    def initvals(self, shape=None, lower=-1, upper=1):
        """
        Create a set of appropriate initial values for the variable.
        NB! It is important that proper initial values are used,
        as only properly ordered values are in the range of the transform.

        Parameters
        ----------
        shape: tuple, default None
            shape of the initial values. If None, adj_mat.shape[:-1] is used
        lower: float, default -1
            lower bound for the initial values
        upper: float, default 1
            upper bound for the initial values

        Returns
        -------
        vals: ndarray
            initial values for the transformed variable

        Raises
        ------
        ValueError
            If the trailing dimensions of ``shape`` differ from
            ``adj_mat.shape[:-1]``.
        """

        base_shape = self.dag.shape[:-1]
        if shape is None:
            shape = base_shape
        # Normalise to a tuple so that lists compare and concatenate correctly
        shape = tuple(shape)

        if shape[-len(base_shape) :] != base_shape:
            raise ValueError("Shape must match the shape of the adjacency matrix")

        # Create the initial values: evenly spaced values assigned to the
        # nodes according to their position in the topological order
        vals = np.linspace(lower, upper, self.dag.shape[-2])
        inds = np.argsort(self.ts_inds, axis=-1)
        ivals = vals[inds]

        # Expand the initial values to the extra dimensions
        extra_dims = shape[: -len(base_shape)]
        ivals = np.tile(ivals, extra_dims + (1,) * len(base_shape))

        return ivals

    def backward(self, value, *inputs):
        """
        Map unconstrained values to values that satisfy the partial order.

        Nodes without predecessors are passed through unchanged; every
        other node becomes the maximum of its (already computed)
        predecessors plus ``exp`` of its unconstrained value.
        """
        minv = dtype_minval(value.dtype)
        # One extra trailing slot filled with the dtype minimum: the -1
        # padding in self.dag indexes it, so padding never wins the max
        x = pt.concatenate(
            [pt.zeros_like(value), pt.full(value.shape[:-1], minv)[..., None]], axis=-1
        )

        # Indices to allow broadcasting the max over the last dimension
        # (i[..., None] keeps this valid for any number of batch dimensions)
        idx = np.ix_(*[np.arange(s) for s in self.dag.shape[:-2]])
        idx2 = tuple(np.tile(i[..., None], self.dag.shape[-1]) for i in idx)

        # Has to be done stepwise as next steps depend on previous values
        # Also has to be done in topological order, hence the ts_inds
        for i in range(self.dag.shape[-2]):
            tsi = self.ts_inds[..., i]
            if len(tsi.shape) == 0:
                tsi = int(tsi)  # if shape 0, it's a scalar
            ni = (*idx, tsi)  # i-th node in topological order
            eni = (Ellipsis, *ni)
            ist = self.is_start[ni]

            mval = pt.max(x[(Ellipsis, *idx2, self.dag[ni])], axis=-1)
            x = pt.set_subtensor(x[eni], ist * value[eni] + (1 - ist) * (mval + pt.exp(value[eni])))
        # Drop the helper slot that held the dtype minimum
        return x[..., :-1]

    def forward(self, value, *inputs):
        """
        Map partially ordered values to unconstrained values.

        Nodes without predecessors are passed through unchanged; every
        other node becomes the log of its distance above the maximum of
        its predecessors.
        """
        minv = dtype_minval(value.dtype)
        # Extra trailing slot with the dtype minimum, picked up by the -1
        # padding entries of self.dag
        vx = pt.concatenate([value, pt.full(value.shape[:-1], minv)[..., None]], axis=-1)

        # Indices to allow broadcasting the max over the last dimension
        # (i[..., None, None] keeps this valid for any number of batch dimensions)
        idx = np.ix_(*[np.arange(s) for s in self.dag.shape[:-2]])
        idx = tuple(np.tile(i[..., None, None], self.dag.shape[-2:]) for i in idx)

        y = self.is_start * value + (1 - self.is_start) * (
            pt.log(value - pt.max(vx[(Ellipsis, *idx, self.dag[..., :])], axis=-1))
        )

        return y

    def log_jac_det(self, value, *inputs):
        """
        Sum the unconstrained values of all nodes that have predecessors.

        Those are the nodes ``backward`` passes through ``exp``; nodes
        without predecessors are copied unchanged and contribute nothing.
        """
        return pt.sum(value * (1 - self.is_start), axis=-1)

tests/distributions/test_transform.py

+77
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,77 @@
1+
# Copyright 2025 The PyMC Developers
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
import numpy as np
15+
import pymc as pm
16+
17+
from pymc_extras.distributions.transforms import PartialOrder
18+
19+
20+
class TestPartialOrder:
    """Checks for the ``PartialOrder`` transform."""

    adj_mats = np.array(
        [
            # 0 < {1, 2} < 3
            [[0, 1, 1, 0], [0, 0, 0, 1], [0, 0, 0, 1], [0, 0, 0, 0]],
            # 1 < 0 < 3 < 2
            [[0, 0, 0, 1], [1, 0, 0, 0], [0, 0, 0, 0], [0, 0, 1, 0]],
        ]
    )

    valid_values = np.array([[0, 2, 1, 3], [1, 0, 3, 2]], dtype=float)

    def test_forward_backward_dimensionality(self):
        """forward and backward are inverses of each other, also when the
        data carries extra leading dimensions."""
        batched = PartialOrder(self.adj_mats)
        single = PartialOrder(self.adj_mats[0])
        both_rows = self.valid_values
        first_row = self.valid_values[0]

        cases = [
            (both_rows, batched),
            (batched.initvals(), batched),
            (first_row, single),
            (single.initvals(), single),
            (np.tile(first_row, (2, 1)), single),
            (np.tile(first_row, (2, 3, 2, 1)), single),
            (np.tile(both_rows, (2, 3, 2, 1, 1)), batched),
        ]

        for values, transform in cases:
            unconstrained = transform.forward(values)
            roundtrip = transform.backward(unconstrained)
            np.testing.assert_allclose(roundtrip.eval(), values)

    def test_sample_model(self):
        """Posterior draws respect both partial orders."""
        transform = PartialOrder(self.adj_mats)
        with pm.Model():
            pm.Normal(
                "x",
                size=(3, 2, 4),
                transform=transform,
                initval=transform.initvals(shape=(3, 2, 4), lower=-1, upper=1),
            )
            idata = pm.sample()

        # Check that the order constraints are satisfied
        # Move chain, draw and "3" dimensions to the back
        draws = idata.posterior.x.values.transpose(3, 4, 0, 1, 2)

        first = draws[0]  # 0 < {1, 2} < 3
        assert (first[0] < first[1]).all()
        assert (first[0] < first[2]).all()
        assert (first[1] < first[3]).all()
        assert (first[2] < first[3]).all()

        second = draws[1]  # 1 < 0 < 3 < 2
        assert (second[1] < second[0]).all()
        assert (second[0] < second[3]).all()
        assert (second[3] < second[2]).all()

0 commit comments

Comments
 (0)