|
19 | 19 | __all__ = ["PartialOrder"]
|
20 | 20 |
|
21 | 21 |
|
22 |
| -# Find the minimum value for a given dtype |
23 | 22 | def dtype_minval(dtype):
|
| 23 | + """Find the minimum value for a given dtype""" |
24 | 24 | return np.iinfo(dtype).min if np.issubdtype(dtype, np.integer) else np.finfo(dtype).min
|
25 | 25 |
|
26 | 26 |
|
27 |
| -# A padded version of np.where |
28 | 27 | def padded_where(x, to_len, padval=-1):
|
| 28 | + """A padded version of np.where""" |
29 | 29 | w = np.where(x)
|
30 | 30 | return np.concatenate([w[0], np.full(to_len - len(w[0]), padval)])
|
31 | 31 |
|
32 | 32 |
|
33 |
| -# Partial order transform |
34 | 33 | class PartialOrder(Transform):
|
35 | 34 | """Create a PartialOrder transform
|
36 | 35 |
|
37 |
| - This is a more flexible version of the pymc ordered transform that |
| 36 | + A more flexible version of the pymc ordered transform that |
38 | 37 | allows specifying a (strict) partial order on the elements.
|
39 | 38 |
|
40 |
| - It works in O(N*D) in runtime, but takes O(N^3) in initialization, |
41 |
| - where N is the number of nodes in the dag and |
42 |
| - D is the maximum in-degree of a node in the transitive reduction. |
43 |
| -
|
| 39 | + Examples |
| 40 | + -------- |
| 41 | + .. code:: python |
| 42 | +
|
| 43 | + import numpy as np |
| 44 | + import pymc as pm |
| 45 | + import pymc_extras as pmx |
| 46 | +
|
| 47 | + # Define two partial orders on 4 elements |
| 48 | + # am[i,j] = 1 means i < j |
| 49 | + adj_mats = np.array([ |
| 50 | + # 0 < {1, 2} < 3 |
| 51 | + [[0, 1, 1, 0], |
| 52 | + [0, 0, 0, 1], |
| 53 | + [0, 0, 0, 1], |
| 54 | + [0, 0, 0, 0]], |
| 55 | +
|
| 56 | + # 1 < 0 < 3 < 2 |
| 57 | + [[0, 0, 0, 1], |
| 58 | + [1, 0, 0, 0], |
| 59 | + [0, 0, 0, 0], |
| 60 | + [0, 0, 1, 0]], |
| 61 | + ]) |
| 62 | +
|
| 63 | + # Create the partial order from the adjacency matrices |
| 64 | + po = pmx.PartialOrder(adj_mats) |
| 65 | +
|
| 66 | + with pm.Model() as model: |
| 67 | + # Generate 3 samples from both partial orders |
| 68 | + pm.Normal("po_vals", shape=(3,2,4), transform=po, |
| 69 | + initval=po.initvals((3,2,4))) |
| 70 | +
|
| 71 | + idata = pm.sample() |
| 72 | +
|
| 73 | + # Verify that for first po, the zeroth element is always the smallest |
| 74 | + assert (idata.posterior['po_vals'][:,:,:,0,0] < |
| 75 | + idata.posterior['po_vals'][:,:,:,0,1:]).all() |
| 76 | +
|
| 77 | + # Verify that for second po, the second element is always the largest |
| 78 | + assert (idata.posterior['po_vals'][:,:,:,1,2] >= |
| 79 | + idata.posterior['po_vals'][:,:,:,1,:]).all() |
| 80 | +
|
| 81 | + Technical notes |
| 82 | + ---------------- |
| 83 | + Partial order needs to be strict, i.e. without equalities. |
| 84 | + A DAG defining the partial order is sufficient, as transitive closure is automatically computed. |
| 85 | + Code works in O(N*D) in runtime, but takes O(N^3) in initialization, |
| 86 | + where N is the number of nodes in the dag and D is the maximum |
| 87 | + in-degree of a node in the transitive reduction. |
44 | 88 | """
|
45 | 89 |
|
46 | 90 | name = "partial_order"
|
47 | 91 |
|
48 | 92 | def __init__(self, adj_mat):
|
49 | 93 | """
|
| 94 | + Initialize the PartialOrder transform |
| 95 | +
|
50 | 96 | Parameters
|
51 | 97 | ----------
|
52 | 98 | adj_mat: ndarray
|
@@ -99,10 +145,43 @@ def __init__(self, adj_mat):
|
99 | 145 | self.dag = np.swapaxes(dag_T, -2, -1)
|
100 | 146 | self.is_start = np.all(self.dag[..., :, :] == -1, axis=-1)
|
101 | 147 |
|
102 |
| - def initvals(self, lower=-1, upper=1): |
| 148 | + def initvals(self, shape=None, lower=-1, upper=1): |
| 149 | + """ |
| 150 | + Create a set of appropriate initial values for the variable. |
| 151 | + NB! It is important that proper initial values are used, |
| 152 | + as only properly ordered values are in the range of the transform. |
| 153 | +
|
| 154 | + Parameters |
| 155 | + ---------- |
| 156 | + shape: tuple, default None |
| 157 | + shape of the initial values. If None, adj_mat[:-1] is used |
| 158 | + lower: float, default -1 |
| 159 | + lower bound for the initial values |
| 160 | + upper: float, default 1 |
| 161 | + upper bound for the initial values |
| 162 | +
|
| 163 | + Returns |
| 164 | + ------- |
| 165 | + vals: ndarray |
| 166 | + initial values for the transformed variable |
| 167 | + """ |
| 168 | + |
| 169 | + if shape is None: |
| 170 | + shape = self.dag.shape[:-1] |
| 171 | + |
| 172 | + if shape[-len(self.dag.shape[:-1]) :] != self.dag.shape[:-1]: |
| 173 | + raise ValueError("Shape must match the shape of the adjacency matrix") |
| 174 | + |
| 175 | + # Create the initial values |
103 | 176 | vals = np.linspace(lower, upper, self.dag.shape[-2])
|
104 | 177 | inds = np.argsort(self.ts_inds, axis=-1)
|
105 |
| - return vals[inds] |
| 178 | + ivals = vals[inds] |
| 179 | + |
| 180 | + # Expand the initial values to the extra dimensions |
| 181 | + extra_dims = shape[: -len(self.dag.shape[:-1])] |
| 182 | + ivals = np.tile(ivals, extra_dims + tuple([1] * len(self.dag.shape[:-1]))) |
| 183 | + |
| 184 | + return ivals |
106 | 185 |
|
107 | 186 | def backward(self, value, *inputs):
|
108 | 187 | minv = dtype_minval(value.dtype)
|
|
0 commit comments