Commit cab3ae9

Jonathan Dekermanjian authored and committed
replaced all pymc potential with pymc censored (pymc-devs#750)
* replaced all pymc potential with pymc censored
* removed gumbel_sf function that is not being used
* added rng to samplers
* put back seeds for sampling data observations

Co-authored-by: Jonathan Dekermanjian <[email protected]>
1 parent 9018430 commit cab3ae9
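In outline, the commit swaps hand-written `pm.Potential` log-survival terms for `pm.Censored` wrapped around the latent distribution. A minimal sketch of that pattern, using made-up data, priors, and variable names rather than the notebook's own:

```python
import numpy as np
import pymc as pm

# Hypothetical right-censored survival times: follow-up ends at t = 10,
# so subjects still event-free at that point are recorded as exactly 10.
censoring_time = 10.0
times = np.array([2.3, 5.1, 10.0, 7.4, 10.0, 4.2])

with pm.Model():
    alpha = pm.HalfNormal("alpha", sigma=1.0)  # Weibull shape
    beta = pm.HalfNormal("beta", sigma=10.0)  # Weibull scale

    # Instead of adding the censored log-survival terms by hand via
    # pm.Potential, wrap the latent distribution in pm.Censored so the
    # censoring is part of the likelihood itself.
    latent = pm.Weibull.dist(alpha=alpha, beta=beta)
    pm.Censored("obs", latent, lower=None, upper=censoring_time, observed=times)
```

Because the censoring then lives inside the likelihood, prior and posterior predictive sampling take it into account, which the `pm.Potential` version could not offer.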

File tree

2 files changed (+377, -322 lines)


examples/survival_analysis/weibull_aft.ipynb (+325, -290)

Large diffs are not rendered by default.

examples/survival_analysis/weibull_aft.myst.md (+52, -32)
@@ -8,6 +8,9 @@ kernelspec:
   display_name: pymc
   language: python
   name: python3
+myst:
+  substitutions:
+    extra_dependencies: statsmodels
 ---

 (weibull_aft)=
@@ -25,15 +28,26 @@ import arviz as az
 import numpy as np
 import pymc as pm
 import pytensor.tensor as pt
-import statsmodels.api as sm

 print(f"Running on PyMC v{pm.__version__}")
 ```

+:::{include} ../extra_installs.md
+:::
+
+```{code-cell} ipython3
+# These dependencies need to be installed separately from PyMC
+import statsmodels.api as sm
+```
+
 ```{code-cell} ipython3
 %config InlineBackend.figure_format = 'retina'
+# These seeds are for sampling data observations
 RANDOM_SEED = 8927
 np.random.seed(RANDOM_SEED)
+# Set a seed for reproducibility of posterior results
+seed: int = sum(map(ord, "aft_weibull"))
+rng: np.random.Generator = np.random.default_rng(seed=seed)
 az.style.use("arviz-darkgrid")
 ```

@@ -71,7 +85,9 @@ censored[:5]

 We have a unique problem when modelling censored data. Strictly speaking, we don't have any _data_ for censored values: we only know the _number_ of values that were censored. How can we include this information in our model?

-One way do this is by making use of `pm.Potential`. The [PyMC2 docs](https://pymc-devs.github.io/pymc/modelbuilding.html#the-potential-class) explain its usage very well. Essentially, declaring `pm.Potential('x', logp)` will add `logp` to the log-likelihood of the model.
+One way to do this is by making use of `pm.Potential`. The [PyMC2 docs](https://pymc-devs.github.io/pymc/modelbuilding.html#the-potential-class) explain its usage very well. Essentially, declaring `pm.Potential('x', logp)` will add `logp` to the log-likelihood of the model.
+
+However, `pm.Potential` only affects probability-based sampling, which excludes `pm.sample_prior_predictive` and `pm.sample_posterior_predictive`. We can overcome these limitations by using `pm.Censored` instead: we can model our right-censored data by defining the `upper` argument of `pm.Censored`.

 +++

@@ -80,36 +96,40 @@ One way do this is by making use of `pm.Potential`. The [PyMC2 docs](https://pym
 This parameterization is an intuitive, straightforward parameterization of the Weibull survival function. This is probably the first parameterization to come to one's mind.

 ```{code-cell} ipython3
-def weibull_lccdf(x, alpha, beta):
-    """Log complementary cdf of Weibull distribution."""
-    return -((x / beta) ** alpha)
+# normalize the event time between 0 and 1
+y_norm = y / np.max(y)
+```
+
+```{code-cell} ipython3
+# If censored then observed event time else maximum time
+right_censored = [x if x > 0 else np.max(y_norm) for x in y_norm * censored]
 ```

 ```{code-cell} ipython3
 with pm.Model() as model_1:
-    alpha_sd = 10.0
+    alpha_sd = 1.0

-    mu = pm.Normal("mu", mu=0, sigma=100)
+    mu = pm.Normal("mu", mu=0, sigma=1)
     alpha_raw = pm.Normal("a0", mu=0, sigma=0.1)
     alpha = pm.Deterministic("alpha", pt.exp(alpha_sd * alpha_raw))
     beta = pm.Deterministic("beta", pt.exp(mu / alpha))
+    beta_backtransformed = pm.Deterministic("beta_backtransformed", beta * np.max(y))

-    y_obs = pm.Weibull("y_obs", alpha=alpha, beta=beta, observed=y[~censored])
-    y_cens = pm.Potential("y_cens", weibull_lccdf(y[censored], alpha, beta))
+    latent = pm.Weibull.dist(alpha=alpha, beta=beta)
+    y_obs = pm.Censored("Censored_likelihood", latent, upper=right_censored, observed=y_norm)
 ```

 ```{code-cell} ipython3
 with model_1:
-    # Change init to avoid divergences
-    data_1 = pm.sample(target_accept=0.9, init="adapt_diag")
+    idata_param1 = pm.sample(nuts_sampler="numpyro", random_seed=rng)
 ```

 ```{code-cell} ipython3
-az.plot_trace(data_1, var_names=["alpha", "beta"])
+az.plot_trace(idata_param1, var_names=["alpha", "beta"])
 ```

 ```{code-cell} ipython3
-az.summary(data_1, var_names=["alpha", "beta"], round_to=2)
+az.summary(idata_param1, var_names=["alpha", "beta", "beta_backtransformed"], round_to=2)
 ```

 ## Parameterization 2
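The new `beta_backtransformed` deterministic relies on the fact that rescaling Weibull-distributed times rescales the scale parameter by the same factor, so fitting on `y_norm = y / np.max(y)` and multiplying `beta` by `np.max(y)` recovers the scale in the original time units. A quick simulation check of that property, with arbitrary values unrelated to the notebook's data:

```python
import numpy as np

rng = np.random.default_rng(0)
alpha, beta, c = 2.0, 3.0, 10.0  # arbitrary shape, scale, and rescaling factor
n = 200_000

t = beta * rng.weibull(alpha, size=n)  # T ~ Weibull(alpha, beta)
t_rescaled = (c * beta) * rng.weibull(alpha, size=n)  # Weibull(alpha, c * beta)

# Multiplying Weibull draws by c is the same as multiplying the scale by c,
# so the quantiles of c * T match those of Weibull(alpha, c * beta)
print(np.quantile(c * t, [0.25, 0.5, 0.75]))
print(np.quantile(t_rescaled, [0.25, 0.5, 0.75]))
```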
@@ -120,26 +140,26 @@ For more information, see [this Stan example model](https://github.com/stan-dev/

 ```{code-cell} ipython3
 with pm.Model() as model_2:
-    alpha = pm.Normal("alpha", mu=0, sigma=10)
-    r = pm.Gamma("r", alpha=1, beta=0.001, testval=0.25)
+    alpha = pm.Normal("alpha", mu=0, sigma=1)
+    r = pm.Gamma("r", alpha=2, beta=1)
     beta = pm.Deterministic("beta", pt.exp(-alpha / r))
+    beta_backtransformed = pm.Deterministic("beta_backtransformed", beta * np.max(y))

-    y_obs = pm.Weibull("y_obs", alpha=r, beta=beta, observed=y[~censored])
-    y_cens = pm.Potential("y_cens", weibull_lccdf(y[censored], r, beta))
+    latent = pm.Weibull.dist(alpha=r, beta=beta)
+    y_obs = pm.Censored("Censored_likelihood", latent, upper=right_censored, observed=y_norm)
 ```

 ```{code-cell} ipython3
 with model_2:
-    # Increase target_accept to avoid divergences
-    data_2 = pm.sample(target_accept=0.9)
+    idata_param2 = pm.sample(nuts_sampler="numpyro", random_seed=rng)
 ```

 ```{code-cell} ipython3
-az.plot_trace(data_2, var_names=["r", "beta"])
+az.plot_trace(idata_param2, var_names=["r", "beta"])
 ```

 ```{code-cell} ipython3
-az.summary(data_2, var_names=["r", "beta"], round_to=2)
+az.summary(idata_param2, var_names=["r", "beta", "beta_backtransformed"], round_to=2)
 ```

 ## Parameterization 3
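One way to read the prior changes above (`sigma=10` to `sigma=1`, `Gamma(1, 0.001)` to `Gamma(2, 1)`) is through the prior they imply for the derived scale `beta = exp(-alpha / r)`; a sketch using `pm.draw`, with arbitrary draw counts and seeds:

```python
import numpy as np
import pymc as pm

# Forward-simulate the updated priors and inspect the implied normalized scale
alpha_draws = pm.draw(pm.Normal.dist(mu=0, sigma=1), draws=2000, random_seed=10)
r_draws = pm.draw(pm.Gamma.dist(alpha=2, beta=1), draws=2000, random_seed=11)

beta_draws = np.exp(-alpha_draws / r_draws)
print(np.percentile(beta_draws, [5, 50, 95]))
```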
@@ -148,41 +168,41 @@ In this parameterization, we model the log-linear error distribution with a Gumb

 ```{code-cell} ipython3
 logtime = np.log(y)
+```

-
-def gumbel_sf(y, mu, sigma):
-    """Gumbel survival function."""
-    return 1.0 - pt.exp(-pt.exp(-(y - mu) / sigma))
+```{code-cell} ipython3
+# If censored then observed event time else maximum time
+right_censored = [x if x > 0 else np.max(logtime) for x in logtime * censored]
 ```

 ```{code-cell} ipython3
 with pm.Model() as model_3:
-    s = pm.HalfNormal("s", tau=5.0)
+    s = pm.HalfNormal("s", tau=3.0)
     gamma = pm.Normal("gamma", mu=0, sigma=5)

-    y_obs = pm.Gumbel("y_obs", mu=gamma, beta=s, observed=logtime[~censored])
-    y_cens = pm.Potential("y_cens", gumbel_sf(y=logtime[censored], mu=gamma, sigma=s))
+    latent = pm.Gumbel.dist(mu=gamma, beta=s)
+    y_obs = pm.Censored("Censored_likelihood", latent, upper=right_censored, observed=logtime)
 ```

 ```{code-cell} ipython3
 with model_3:
-    # Change init to avoid divergences
-    data_3 = pm.sample(init="adapt_diag")
+    idata_param3 = pm.sample(tune=4000, draws=2000, nuts_sampler="numpyro", random_seed=rng)
 ```

 ```{code-cell} ipython3
-az.plot_trace(data_3)
+az.plot_trace(idata_param3)
 ```

 ```{code-cell} ipython3
-az.summary(data_3, round_to=2)
+az.summary(idata_param3, round_to=2)
 ```

 ## Authors

 - Originally collated by [Junpeng Lao](https://junpenglao.xyz/) on Apr 21, 2018. See original code [here](https://github.com/junpenglao/Planet_Sakaar_Data_Science/blob/65447fdb431c78b15fbeaef51b8c059f46c9e8d6/PyMC3QnA/discourse_1107.ipynb).
 - Authored and ported to Jupyter notebook by [George Ho](https://eigenfoo.xyz/) on Jul 15, 2018.
 - Updated for compatibility with PyMC v5 by Chris Fonnesbeck on Jan 16, 2023.
+- Updated to replace `pm.Potential` with `pm.Censored` by Jonathan Dekermanjian on Nov 25, 2024.

 ```{code-cell} ipython3
 %load_ext watermark
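End to end, the payoff of the change is that posterior predictive checks become available for these censored models; a self-contained sketch with a toy exponential model and invented data rather than the notebook's:

```python
import numpy as np
import pymc as pm

# Toy right-censored data with a known bound
upper = 5.0
t_obs = np.array([1.2, 3.4, 5.0, 2.8, 5.0, 4.1, 0.9])

with pm.Model():
    lam = pm.HalfNormal("lam", sigma=1.0)
    pm.Censored("t", pm.Exponential.dist(lam=lam), lower=None, upper=upper, observed=t_obs)

    idata = pm.sample(draws=500, tune=500, chains=2, random_seed=123)
    # A pm.Potential term would not contribute here; with pm.Censored the
    # censoring is reproduced in the predictive draws
    idata.extend(pm.sample_posterior_predictive(idata, random_seed=123))

print(float(idata.posterior_predictive["t"].max()))  # capped at the bound
```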
