Skip to content

Commit 5421a34

Browse files
committed
Added a PartialOrder transform
1 parent 7d62c53 commit 5421a34

File tree

3 files changed

+218
-0
lines changed

3 files changed

+218
-0
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
from pymc_extras.distributions.transforms.partial_order import PartialOrder
2+
3+
__all__ = ["PartialOrder"]
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,144 @@
1+
# Copyright 2025 The PyMC Developers
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
import numpy as np
15+
import pytensor.tensor as pt
16+
17+
from pymc.logprob.transforms import Transform
18+
19+
__all__ = ["PartialOrder"]
20+
21+
22+
# Find the minimum value for a given dtype
23+
def dtype_minval(dtype):
24+
return np.iinfo(dtype).min if np.issubdtype(dtype, np.integer) else np.finfo(dtype).min
25+
26+
27+
# A padded version of np.where
28+
def padded_where(x, to_len, padval=-1):
29+
w = np.where(x)
30+
return np.concatenate([w[0], np.full(to_len - len(w[0]), padval)])
31+
32+
33+
# Partial order transform
34+
class PartialOrder(Transform):
35+
name = "partial_order"
36+
37+
def __init__(self, adj_mat):
38+
"""Create a PartialOrder transform
39+
40+
This is a more flexible version of the pymc ordered transform that
41+
allows specifying a (strict) partial order on the elements.
42+
43+
It works in O(N*D) in runtime, but takes O(N^3) in initialization,
44+
where N is the number of nodes in the dag and
45+
D is the maximum in-degree of a node in the transitive reduction.
46+
47+
Parameters
48+
----------
49+
adj_mat: adjacency matrix for the DAG that generates the partial order,
50+
where adj_mat[i][j] = 1 denotes i<j.
51+
Note this also accepts multiple DAGs if RV is multidimensional
52+
"""
53+
54+
# Basic input checks
55+
if adj_mat.ndim < 2:
56+
raise ValueError("Adjacency matrix must have at least 2 dimensions")
57+
if adj_mat.shape[-2] != adj_mat.shape[-1]:
58+
raise ValueError("Adjacency matrix is not square")
59+
if adj_mat.min() != 0 or adj_mat.max() != 1:
60+
raise ValueError("Adjacency matrix must contain only 0s and 1s")
61+
62+
# Create index over the first ellipsis dimensions
63+
idx = np.ix_(*[np.arange(s) for s in adj_mat.shape[:-2]])
64+
65+
# Transitive closure using Floyd-Warshall
66+
tc = adj_mat.astype(bool)
67+
for k in range(tc.shape[-1]):
68+
tc |= np.logical_and(tc[..., :, k, None], tc[..., None, k, :])
69+
70+
# Check if the dag is acyclic
71+
if np.any(tc.diagonal(axis1=-2, axis2=-1)):
72+
raise ValueError("Partial order contains equalities")
73+
74+
# Transitive reduction using the closure
75+
# This gives the minimum description of the partial order
76+
# This is to minmax the input degree
77+
adj_mat = tc * (1 - np.matmul(tc, tc))
78+
79+
# Find the maximum in-degree of the reduced dag
80+
dag_idim = adj_mat.sum(axis=-2).max()
81+
82+
# Topological sort
83+
ts_inds = np.zeros(adj_mat.shape[:-1], dtype=int)
84+
dm = adj_mat.copy()
85+
for i in range(adj_mat.shape[1]):
86+
assert dm.sum(axis=-2).min() == 0 # DAG is acyclic
87+
nind = np.argmin(dm.sum(axis=-2), axis=-1)
88+
dm[(*idx, slice(None), nind)] = 1 # Make nind not show up again
89+
dm[(*idx, nind, slice(None))] = 0 # Allow it's children to show
90+
ts_inds[(*idx, i)] = nind
91+
self.ts_inds = ts_inds
92+
93+
# Change the dag to adjacency lists (with -1 for NA)
94+
dag_T = np.apply_along_axis(padded_where, axis=-2, arr=adj_mat, padval=-1, to_len=dag_idim)
95+
self.dag = np.swapaxes(dag_T, -2, -1)
96+
self.is_start = np.all(self.dag[..., :, :] == -1, axis=-1)
97+
98+
def initvals(self, lower=-1, upper=1):
99+
vals = np.linspace(lower, upper, self.dag.shape[-2])
100+
inds = np.argsort(self.ts_inds, axis=-1)
101+
return vals[inds]
102+
103+
def backward(self, value, *inputs):
104+
minv = dtype_minval(value.dtype)
105+
x = pt.concatenate(
106+
[pt.zeros_like(value), pt.full(value.shape[:-1], minv)[..., None]], axis=-1
107+
)
108+
109+
# Indices to allow broadcasting the max over the last dimension
110+
idx = np.ix_(*[np.arange(s) for s in self.dag.shape[:-2]])
111+
idx2 = tuple(np.tile(i[:, None], self.dag.shape[-1]) for i in idx)
112+
113+
# Has to be done stepwise as next steps depend on previous values
114+
# Also has to be done in topological order, hence the ts_inds
115+
for i in range(self.dag.shape[-2]):
116+
tsi = self.ts_inds[..., i]
117+
if len(tsi.shape) == 0:
118+
tsi = int(tsi) # if shape 0, it's a scalar
119+
ni = (*idx, tsi) # i-th node in topological order
120+
eni = (Ellipsis, *ni)
121+
ist = self.is_start[ni]
122+
123+
mval = pt.max(x[(Ellipsis, *idx2, self.dag[ni])], axis=-1)
124+
x = pt.set_subtensor(x[eni], ist * value[eni] + (1 - ist) * (mval + pt.exp(value[eni])))
125+
return x[..., :-1]
126+
127+
def forward(self, value, *inputs):
128+
y = pt.zeros_like(value)
129+
130+
minv = dtype_minval(value.dtype)
131+
vx = pt.concatenate([value, pt.full(value.shape[:-1], minv)[..., None]], axis=-1)
132+
133+
# Indices to allow broadcasting the max over the last dimension
134+
idx = np.ix_(*[np.arange(s) for s in self.dag.shape[:-2]])
135+
idx = tuple(np.tile(i[:, None, None], self.dag.shape[-2:]) for i in idx)
136+
137+
y = self.is_start * value + (1 - self.is_start) * (
138+
pt.log(value - pt.max(vx[(Ellipsis, *idx, self.dag[..., :])], axis=-1))
139+
)
140+
141+
return y
142+
143+
def log_jac_det(self, value, *inputs):
144+
return pt.sum(value * (1 - self.is_start), axis=-1)

tests/distributions/test_transform.py

+71
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,71 @@
1+
# Copyright 2025 The PyMC Developers
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
import numpy as np
15+
import pymc as pm
16+
17+
from pymc_extras.distributions.transforms import PartialOrder
18+
19+
20+
class TestPartialOrder:
21+
adj_mats = np.array(
22+
[
23+
# 0 < {1, 2} < 3
24+
[[0, 1, 1, 0], [0, 0, 0, 1], [0, 0, 0, 1], [0, 0, 0, 0]],
25+
# 1 < 0 < 3 < 2
26+
[[0, 0, 0, 1], [1, 0, 0, 0], [0, 0, 0, 0], [0, 0, 1, 0]],
27+
]
28+
)
29+
30+
valid_values = np.array([[0, 2, 1, 3], [1, 0, 3, 2]], dtype=float)
31+
32+
# Test that forward and backward are inverses of eachother
33+
# And that it works when extra dimensions are added in data
34+
def test_forward_backward_dimensionality(self):
35+
po = PartialOrder(self.adj_mats)
36+
po0 = PartialOrder(self.adj_mats[0])
37+
vv = self.valid_values
38+
vv0 = self.valid_values[0]
39+
40+
testsets = [
41+
(vv, po),
42+
(po.initvals(), po),
43+
(vv0, po0),
44+
(po0.initvals(), po0),
45+
(np.tile(vv0, (2, 1)), po0),
46+
(np.tile(vv0, (2, 3, 2, 1)), po0),
47+
(np.tile(vv, (2, 3, 2, 1, 1)), po),
48+
]
49+
50+
for vv, po in testsets:
51+
fw = po.forward(vv)
52+
bw = po.backward(fw)
53+
np.testing.assert_allclose(bw.eval(), vv)
54+
55+
def test_sample_model(self):
56+
po = PartialOrder(self.adj_mats)
57+
with pm.Model() as model:
58+
x = pm.Normal("x", size=(2, 4), transform=po, initval=po.initvals(-1, 1))
59+
idata = pm.sample()
60+
61+
# Check that the order constraints are satisfied
62+
xvs = idata.posterior.x.values.transpose(2, 3, 0, 1)
63+
x0 = xvs[0] # 0 < {1, 2} < 3
64+
assert (
65+
(x0[0] < x0[1]).all()
66+
and (x0[0] < x0[2]).all()
67+
and (x0[1] < x0[3]).all()
68+
and (x0[2] < x0[3]).all()
69+
)
70+
x1 = xvs[1] # 1 < 0 < 3 < 2
71+
assert (x1[1] < x1[0]).all() and (x1[0] < x1[3]).all() and (x1[3] < x1[2]).all()

0 commit comments

Comments
 (0)