Skip to content

Commit e89c788

Browse files
committed
Improved documentation with an example
1 parent b38c2f0 commit e89c788

File tree

2 files changed

+97
-12
lines changed

2 files changed

+97
-12
lines changed

pymc_extras/distributions/transforms/partial_order.py

+89-10
Original file line numberDiff line numberDiff line change
@@ -19,34 +19,80 @@
1919
__all__ = ["PartialOrder"]
2020

2121

22-
# Find the minimum value for a given dtype
2322
def dtype_minval(dtype):
23+
"""Find the minimum value for a given dtype"""
2424
return np.iinfo(dtype).min if np.issubdtype(dtype, np.integer) else np.finfo(dtype).min
2525

2626

27-
# A padded version of np.where
2827
def padded_where(x, to_len, padval=-1):
28+
"""A padded version of np.where"""
2929
w = np.where(x)
3030
return np.concatenate([w[0], np.full(to_len - len(w[0]), padval)])
3131

3232

33-
# Partial order transform
3433
class PartialOrder(Transform):
3534
"""Create a PartialOrder transform
3635
37-
This is a more flexible version of the pymc ordered transform that
36+
A more flexible version of the pymc ordered transform that
3837
allows specifying a (strict) partial order on the elements.
3938
40-
It works in O(N*D) in runtime, but takes O(N^3) in initialization,
41-
where N is the number of nodes in the dag and
42-
D is the maximum in-degree of a node in the transitive reduction.
43-
39+
Examples
40+
--------
41+
.. code:: python
42+
43+
import numpy as np
44+
import pymc as pm
45+
import pymc_extras as pmx
46+
47+
# Define two partial orders on 4 elements
48+
# am[i,j] = 1 means i < j
49+
adj_mats = np.array([
50+
# 0 < {1, 2} < 3
51+
[[0, 1, 1, 0],
52+
[0, 0, 0, 1],
53+
[0, 0, 0, 1],
54+
[0, 0, 0, 0]],
55+
56+
# 1 < 0 < 3 < 2
57+
[[0, 0, 0, 1],
58+
[1, 0, 0, 0],
59+
[0, 0, 0, 0],
60+
[0, 0, 1, 0]],
61+
])
62+
63+
# Create the partial order from the adjacency matrices
64+
po = pmx.PartialOrder(adj_mats)
65+
66+
with pm.Model() as model:
67+
# Generate 3 samples from both partial orders
68+
pm.Normal("po_vals", shape=(3,2,4), transform=po,
69+
initval=po.initvals((3,2,4)))
70+
71+
idata = pm.sample()
72+
73+
# Verify that for first po, the zeroth element is always the smallest
74+
assert (idata.posterior['po_vals'][:,:,:,0,0] <
75+
idata.posterior['po_vals'][:,:,:,0,1:]).all()
76+
77+
# Verify that for second po, the second element is always the largest
78+
assert (idata.posterior['po_vals'][:,:,:,1,2] >=
79+
idata.posterior['po_vals'][:,:,:,1,:]).all()
80+
81+
Technical notes
82+
----------------
83+
Partial order needs to be strict, i.e. without equalities.
84+
A DAG defining the partial order is sufficient, as transitive closure is automatically computed.
85+
Code works in O(N*D) in runtime, but takes O(N^3) in initialization,
86+
where N is the number of nodes in the dag and D is the maximum
87+
in-degree of a node in the transitive reduction.
4488
"""
4589

4690
name = "partial_order"
4791

4892
def __init__(self, adj_mat):
4993
"""
94+
Initialize the PartialOrder transform
95+
5096
Parameters
5197
----------
5298
adj_mat: ndarray
@@ -99,10 +145,43 @@ def __init__(self, adj_mat):
99145
self.dag = np.swapaxes(dag_T, -2, -1)
100146
self.is_start = np.all(self.dag[..., :, :] == -1, axis=-1)
101147

102-
def initvals(self, lower=-1, upper=1):
148+
def initvals(self, shape=None, lower=-1, upper=1):
149+
"""
150+
Create a set of appropriate initial values for the variable.
151+
NB! It is important that proper initial values are used,
152+
as only properly ordered values are in the range of the transform.
153+
154+
Parameters
155+
----------
156+
shape: tuple, default None
157+
shape of the initial values. If None, adj_mat[:-1] is used
158+
lower: float, default -1
159+
lower bound for the initial values
160+
upper: float, default 1
161+
upper bound for the initial values
162+
163+
Returns
164+
-------
165+
vals: ndarray
166+
initial values for the transformed variable
167+
"""
168+
169+
if shape is None:
170+
shape = self.dag.shape[:-1]
171+
172+
if shape[-len(self.dag.shape[:-1]) :] != self.dag.shape[:-1]:
173+
raise ValueError("Shape must match the shape of the adjacency matrix")
174+
175+
# Create the initial values
103176
vals = np.linspace(lower, upper, self.dag.shape[-2])
104177
inds = np.argsort(self.ts_inds, axis=-1)
105-
return vals[inds]
178+
ivals = vals[inds]
179+
180+
# Expand the initial values to the extra dimensions
181+
extra_dims = shape[: -len(self.dag.shape[:-1])]
182+
ivals = np.tile(ivals, extra_dims + tuple([1] * len(self.dag.shape[:-1])))
183+
184+
return ivals
106185

107186
def backward(self, value, *inputs):
108187
minv = dtype_minval(value.dtype)

tests/distributions/test_transform.py

+8-2
Original file line numberDiff line numberDiff line change
@@ -55,11 +55,17 @@ def test_forward_backward_dimensionality(self):
5555
def test_sample_model(self):
5656
po = PartialOrder(self.adj_mats)
5757
with pm.Model() as model:
58-
x = pm.Normal("x", size=(2, 4), transform=po, initval=po.initvals(-1, 1))
58+
x = pm.Normal(
59+
"x",
60+
size=(3, 2, 4),
61+
transform=po,
62+
initval=po.initvals(shape=(3, 2, 4), lower=-1, upper=1),
63+
)
5964
idata = pm.sample()
6065

6166
# Check that the order constraints are satisfied
62-
xvs = idata.posterior.x.values.transpose(2, 3, 0, 1)
67+
# Move chain, draw and "3" dimensions to the back
68+
xvs = idata.posterior.x.values.transpose(3, 4, 0, 1, 2)
6369
x0 = xvs[0] # 0 < {1, 2} < 3
6470
assert (
6571
(x0[0] < x0[1]).all()

0 commit comments

Comments
 (0)