Commit 299e17b

mhauru, penelopeysm, torfjelde, and yebai authored
Accumulators, stage 1 (#885)
* Release 0.36
* AbstractPPL 0.11 + change prefixing behaviour (#830)
* AbstractPPL 0.11; change prefixing behaviour
* Use DynamicPPL.prefix rather than overloading
* Remove VarInfo(VarInfo, params) (#870)
* Unify `{untyped,typed}_{vector_,}varinfo` constructor functions (#879)
* Unify {Untyped,Typed}{Vector,}VarInfo constructors
* Update invocations
* NTVarInfo
* Fix tests
* More fixes
* Fixes
* Fixes
* Fixes
* Use lowercase functions, don't deprecate VarInfo
* Rewrite VarInfo docstring
* Fix methods
* Fix methods (really)
* Draft of accumulators
* Fix some variable names
* Fix pointwise_logdensities, gut tilde_observe, remove resetlogp!!
* Map rather than broadcast
  Co-authored-by: Tor Erlend Fjelde <[email protected]>
* Start documenting accumulators
* Use Val{symbols} instead of AccTypes to index
* More documentation for accumulators
* Link varinfo by default in AD testing utilities; make test suite run on linked varinfos (#890)
* Link VarInfo by default
* Tweak interface
* Fix tests
* Fix interface so that callers can inspect results
* Document
* Fix tests
* Fix changelog
* Test linked varinfos
  Closes #891
* Fix docstring + use AbstractFloat
* Fix resetlogp!! and type stability for accumulators
* Fix type rigidity of LogProbs and NumProduce
* Fix uses of getlogp and other assorted issues
* setaccs!! nicer interface and logdensity function fixes
* Revert back to calling the macro @addlogprob!
* Remove a dead test
* Clarify a comment
* Implement split/combine for PointwiseLogdensityAccumulator
* Switch ThreadSafeVarInfo.accs_by_thread to be a tuple
* Fix `condition` and `fix` in submodels (#892)
* Fix conditioning in submodels
* Simplify contextual_isassumption
* Add documentation
* Fix some tests
* Add tests; fix a bunch of nested submodel issues
* Fix fix as well
* Fix doctests
* Add unit tests for new functions
* Add changelog entry
* Update changelog
  Co-authored-by: Hong Ge <[email protected]>
* Finish docs
* Add a test for conditioning submodel via arguments
* Clean new tests up a bit
* Fix for VarNames with non-identity lenses
* Apply suggestions from code review
  Co-authored-by: Markus Hauru <[email protected]>
* Apply suggestions from code review
* Make PrefixContext contain a varname rather than symbol (#896)
---------
Co-authored-by: Hong Ge <[email protected]>
Co-authored-by: Markus Hauru <[email protected]>
* Revert ThreadSafeVarInfo back to Vectors and fix some AD type casting in (Simple)VarInfo
* Improve accumulator docs
* Add test/accumulators.jl
* Docs fixes
* Various small fixes
* Make DynamicTransformation not use accumulators other than LogPrior
* Fix variable order and name of map_accumulator!!
* Typo fixing
* Small improvement to ThreadSafeVarInfo
* Fix demo_dot_assume_observe_submodel prefixing
* Typo fixing
* Miscellaneous small fixes
* HISTORY entry and more miscellanea
* Add more tests for accumulators
* Improve accumulators docstrings
* Fix a typo
* Expand HISTORY entry
* Add accumulators to API docs
* Remove unexported functions from API docs
* Add NamedTuple methods for get/set/acclogp
* Fix setlogp!! with single scalar to error
* Export AbstractAccumulator, fix a docs typo
* Apply suggestions from code review
  Co-authored-by: Penelope Yong <[email protected]>
* Rename LogPrior -> LogPriorAccumulator, and Likelihood and NumProduce
* Type bound log prob accumulators with T<:Real
* Add @addlogprior! and @addloglikelihood!
* Apply suggestions from code review
  Co-authored-by: Penelope Yong <[email protected]>
* Move default accumulators to default_accumulators.jl
* Fix some tests
* Introduce default_accumulators()
* Go back to only having @addlogprob!
* Fix tilde_observe!! prefixing
* Fix default_accumulators internal type
* Make unflatten more type stable, and add a test for it
* Always print all benchmark results
* Move NumProduce VI functions to abstract_varinfo.jl
---------
Co-authored-by: Penelope Yong <[email protected]>
Co-authored-by: Tor Erlend Fjelde <[email protected]>
Co-authored-by: Hong Ge <[email protected]>
1 parent 8135113 commit 299e17b

42 files changed: +1768 -915 lines changed

HISTORY.md

+18
@@ -1,5 +1,23 @@
 # DynamicPPL Changelog
 
+## 0.37.0
+
+**Breaking changes**
+
+### Accumulators
+
+This release overhauls how VarInfo objects track variables such as the log joint probability. The new approach is to use what we call accumulators: Objects that the VarInfo carries on it that may change their state at each `tilde_assume!!` and `tilde_observe!!` call based on the value of the variable in question. They replace both variables that were previously hard-coded in the `VarInfo` object (`logp` and `num_produce`) and some contexts. This brings with it a number of breaking changes:
+
+- `PriorContext` and `LikelihoodContext` no longer exist. By default, a `VarInfo` tracks both the log prior and the log likelihood separately, and they can be accessed with `getlogprior` and `getloglikelihood`. If you want to execute a model while only accumulating one of the two (to save clock cycles), you can do so by creating a `VarInfo` that only has one accumulator in it, e.g. `varinfo = setaccs!!(varinfo, (LogPriorAccumulator(),))`.
+- `MiniBatchContext` does not exist anymore. It can be replaced by creating and using a custom accumulator that replaces the default `LikelihoodContext`. We may introduce such an accumulator in DynamicPPL in the future, but for now you'll need to do it yourself.
+- `tilde_observe` and `observe` have been removed. `tilde_observe!!` still exists, and any contexts should modify its behaviour. We may further rework the call stack under `tilde_observe!!` in the near future.
+- `tilde_assume` no longer returns the log density of the current assumption as its second return value. We may further rework the `tilde_assume!!` call stack as well.
+- For literal observation statements like `0.0 ~ Normal(blahblah)` we used to call `tilde_observe!!` without the `vn` argument. This method no longer exists. Rather we call `tilde_observe!!` with `vn` set to `nothing`.
+- `set/reset/increment_num_produce!` have become `set/reset/increment_num_produce!!` (note the second exclamation mark). They are no longer guaranteed to modify the `VarInfo` in place, and one should always use the return value.
+- `@addlogprob!` now _always_ adds to the log likelihood. Previously it added to the log probability that the execution context specified, e.g. the log prior when using `PriorContext`.
+- `getlogp` now returns a `NamedTuple` with keys `logprior` and `loglikelihood`. If you want the log joint probability, which is what `getlogp` used to return, use `getlogjoint`.
+- Correspondingly `setlogp!!` and `acclogp!!` should now be called with a `NamedTuple` with keys `logprior` and `loglikelihood`. The `acclogp!!` method with a single scalar value has been deprecated and falls back on `accloglikelihood!!`, and the single scalar version of `setlogp!!` has been removed. Corresponding setter/accumulator functions exist for the log prior as well.
+
 ## 0.36.0
 
 **Breaking changes**
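
The accumulator idea described in the 0.37.0 changelog entry above can be pictured with a small, dependency-free Julia sketch. Every name in it (`ToyLogPrior`, `ToyLogLikelihood`, `ToyVarInfo`, `toy_assume`, `toy_observe`, `run_toy_model`) is hypothetical and is not DynamicPPL's implementation; it only assumes that a container carries a tuple of accumulators, that each accumulator decides how to react to an assume or an observe statement, and that updates return a new container rather than mutating in place.

```julia
# Toy model of the accumulator idea. All names here are hypothetical and are
# NOT DynamicPPL types or internals.

# Log density of Normal(mu, sigma) at x, written out to avoid dependencies.
normal_logpdf(x, mu, sigma) = -0.5 * log(2 * pi) - log(sigma) - 0.5 * ((x - mu) / sigma)^2

struct ToyLogPrior
    logp::Float64
end
struct ToyLogLikelihood
    logp::Float64
end

# Each accumulator decides for itself how to react to assume / observe.
on_assume(acc::ToyLogPrior, lp) = ToyLogPrior(acc.logp + lp)
on_assume(acc::ToyLogLikelihood, lp) = acc
on_observe(acc::ToyLogPrior, lp) = acc
on_observe(acc::ToyLogLikelihood, lp) = ToyLogLikelihood(acc.logp + lp)

# The container holds a tuple of accumulators and is never mutated in place,
# mirroring the "always use the return value" convention of `!!` functions.
struct ToyVarInfo{A<:Tuple}
    accs::A
end
toy_assume(vi::ToyVarInfo, lp) = ToyVarInfo(map(acc -> on_assume(acc, lp), vi.accs))
toy_observe(vi::ToyVarInfo, lp) = ToyVarInfo(map(acc -> on_observe(acc, lp), vi.accs))

# "Model": x ~ Normal(0, 1); 0.5 ~ Normal(x, 1), evaluated at a given x.
function run_toy_model(vi::ToyVarInfo, x)
    vi = toy_assume(vi, normal_logpdf(x, 0.0, 1.0))
    vi = toy_observe(vi, normal_logpdf(0.5, x, 1.0))
    return vi
end

# Track both quantities, or only the prior by carrying a single accumulator.
both = run_toy_model(ToyVarInfo((ToyLogPrior(0.0), ToyLogLikelihood(0.0))), 0.2)
prior_only = run_toy_model(ToyVarInfo((ToyLogPrior(0.0),)), 0.2)
```

Dropping an accumulator from the tuple, as in `prior_only`, skips the corresponding bookkeeping entirely. This is the same idea as the `setaccs!!(varinfo, (LogPriorAccumulator(),))` call mentioned in the changelog, although the real types and call stack differ from this sketch.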

Project.toml

+2
@@ -21,6 +21,7 @@ LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
 LogDensityProblems = "6fdf6af0-433a-55f7-b3ed-c6c6e0b8df7c"
 MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
 OrderedCollections = "bac558e1-5e72-5ebc-8fee-abe8a469f55d"
+Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"
 Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
 Requires = "ae029012-a4dd-5104-9daa-d747884805df"
 Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
@@ -68,6 +69,7 @@ MCMCChains = "6"
 MacroTools = "0.5.6"
 Mooncake = "0.4.95"
 OrderedCollections = "1"
+Printf = "1.10"
 Random = "1.6"
 Requires = "1"
 Statistics = "1"

benchmarks/benchmarks.jl

+1
@@ -100,4 +100,5 @@ PrettyTables.pretty_table(
 header=header,
 tf=PrettyTables.tf_markdown,
 formatters=ft_printf("%.1f", [6, 7]),
+crop=:none, # Always print the whole table, even if it doesn't fit in the terminal.
 )

docs/src/api.md

+27-11
@@ -160,7 +160,7 @@ returned(::Model)
 
 ## Utilities
 
-It is possible to manually increase (or decrease) the accumulated log density from within a model function.
+It is possible to manually increase (or decrease) the accumulated log likelihood or prior from within a model function.
 
 ```@docs
 @addlogprob!
@@ -328,9 +328,9 @@ The following functions were used for sequential Monte Carlo methods.
 
 ```@docs
 get_num_produce
-set_num_produce!
-increment_num_produce!
-reset_num_produce!
+set_num_produce!!
+increment_num_produce!!
+reset_num_produce!!
 setorder!
 set_retained_vns_del!
 ```
@@ -345,6 +345,22 @@ Base.empty!
 SimpleVarInfo
 ```
 
+### Accumulators
+
+The subtypes of [`AbstractVarInfo`](@ref) store the cumulative log prior and log likelihood, and sometimes other variables that change during executing, in what are called accumulators.
+
+```@docs
+AbstractAccumulator
+```
+
+DynamicPPL provides the following default accumulators.
+
+```@docs
+LogPriorAccumulator
+LogLikelihoodAccumulator
+NumProduceAccumulator
+```
+
 ### Common API
 
 #### Accumulation of log-probabilities
@@ -353,6 +369,13 @@ SimpleVarInfo
 getlogp
 setlogp!!
 acclogp!!
+getlogjoint
+getlogprior
+setlogprior!!
+acclogprior!!
+getloglikelihood
+setloglikelihood!!
+accloglikelihood!!
 resetlogp!!
 ```
 
@@ -427,9 +450,6 @@ Contexts are subtypes of `AbstractPPL.AbstractContext`.
 ```@docs
 SamplingContext
 DefaultContext
-LikelihoodContext
-PriorContext
-MiniBatchContext
 PrefixContext
 ConditionContext
 ```
@@ -476,7 +496,3 @@ DynamicPPL.Experimental.is_suitable_varinfo
 ```@docs
 tilde_assume
 ```
-
-```@docs
-tilde_observe
-```

ext/DynamicPPLMCMCChainsExt.jl

+6-6
@@ -48,18 +48,18 @@ end
 Sample from the posterior predictive distribution by executing `model` with parameters fixed to each sample
 in `chain`, and return the resulting `Chains`.
 
-The `model` passed to `predict` is often different from the one used to generate `chain`.
-Typically, the model from which `chain` originated treats certain variables as observed (i.e.,
-data points), while the model you pass to `predict` may mark these same variables as missing
-or unobserved. Calling `predict` then leverages the previously inferred parameter values to
+The `model` passed to `predict` is often different from the one used to generate `chain`.
+Typically, the model from which `chain` originated treats certain variables as observed (i.e.,
+data points), while the model you pass to `predict` may mark these same variables as missing
+or unobserved. Calling `predict` then leverages the previously inferred parameter values to
 simulate what new, unobserved data might look like, given your posterior beliefs.
 
 For each parameter configuration in `chain`:
 1. All random variables present in `chain` are fixed to their sampled values.
 2. Any variables not included in `chain` are sampled from their prior distributions.
 
 If `include_all` is `false`, the returned `Chains` will contain only those variables that were not fixed by
-the samples in `chain`. This is useful when you want to sample only new variables from the posterior
+the samples in `chain`. This is useful when you want to sample only new variables from the posterior
 predictive distribution.
 
 # Examples
@@ -124,7 +124,7 @@ function DynamicPPL.predict(
 map(DynamicPPL.varname_and_value_leaves, keys(vals), values(vals)),
 )
 
-return (varname_and_values=varname_vals, logp=DynamicPPL.getlogp(varinfo))
+return (varname_and_values=varname_vals, logp=DynamicPPL.getlogjoint(varinfo))
 end
 
 chain_result = reduce(

src/DynamicPPL.jl

+20-8
@@ -6,6 +6,7 @@ using Bijectors
 using Compat
 using Distributions
 using OrderedCollections: OrderedCollections, OrderedDict
+using Printf: Printf
 
 using AbstractMCMC: AbstractMCMC
 using ADTypes: ADTypes
@@ -46,17 +47,28 @@ import Base:
 export AbstractVarInfo,
 VarInfo,
 SimpleVarInfo,
+AbstractAccumulator,
+LogLikelihoodAccumulator,
+LogPriorAccumulator,
+NumProduceAccumulator,
 push!!,
 empty!!,
 subset,
 getlogp,
+getlogjoint,
+getlogprior,
+getloglikelihood,
 setlogp!!,
+setlogprior!!,
+setloglikelihood!!,
 acclogp!!,
+acclogprior!!,
+accloglikelihood!!,
 resetlogp!!,
 get_num_produce,
-set_num_produce!,
-reset_num_produce!,
-increment_num_produce!,
+set_num_produce!!,
+reset_num_produce!!,
+increment_num_produce!!,
 set_retained_vns_del!,
 is_flagged,
 set_flag!,
@@ -92,15 +104,10 @@ export AbstractVarInfo,
 # Contexts
 SamplingContext,
 DefaultContext,
-LikelihoodContext,
-PriorContext,
-MiniBatchContext,
 PrefixContext,
 ConditionContext,
 assume,
-observe,
 tilde_assume,
-tilde_observe,
 # Pseudo distributions
 NamedDist,
 NoDist,
@@ -146,6 +153,9 @@ macro prob_str(str)
 ))
 end
 
+# TODO(mhauru) We should write down the list of methods that any subtype of AbstractVarInfo
+# has to implement. Not sure what the full list is for parameters values, but for
+# accumulators we only need `getaccs` and `setaccs!!`.
 """
 AbstractVarInfo
 
@@ -166,6 +176,8 @@ include("varname.jl")
 include("distribution_wrappers.jl")
 include("contexts.jl")
 include("varnamedvector.jl")
+include("accumulators.jl")
+include("default_accumulators.jl")
 include("abstract_varinfo.jl")
 include("threadsafe.jl")
 include("varinfo.jl")
