Accumulators, stage 1 #885

Merged on May 2, 2025 (67 commits)
Changes from 57 commits
Commits
061acbe
Release 0.36
penelopeysm Mar 5, 2025
1496868
Merge branch 'main' into breaking
penelopeysm Mar 20, 2025
324e623
Merge branch 'main' into breaking
penelopeysm Mar 22, 2025
bb59885
Merge branch 'main' into breaking
penelopeysm Mar 25, 2025
fc32398
AbstractPPL 0.11 + change prefixing behaviour (#830)
penelopeysm Mar 28, 2025
cc5e581
Remove VarInfo(VarInfo, params) (#870)
penelopeysm Mar 28, 2025
b9c368b
Unify `{untyped,typed}_{vector_,}varinfo` constructor functions (#879)
penelopeysm Apr 9, 2025
1b8f555
Merge remote-tracking branch 'origin/main' into breaking
mhauru Apr 9, 2025
ae9b1cd
Draft of accumulators
mhauru Feb 24, 2025
4fb0bf4
Fix some variable names
mhauru Apr 10, 2025
e410f47
Merge remote-tracking branch 'origin/main' into breaking
penelopeysm Apr 11, 2025
97788bd
Fix pointwise_logdensities, gut tilde_observe, remove resetlogp!!
mhauru Apr 11, 2025
7fe03ec
Map rather than broadcast
mhauru Apr 15, 2025
5ba3530
Merge remote-tracking branch 'origin/main' into breaking
penelopeysm Apr 15, 2025
d49f7be
Start documenting accumulators
mhauru Apr 15, 2025
28bbf1c
Use Val{symbols} instead of AccTypes to index
mhauru Apr 15, 2025
a0ed665
More documentation for accumulators
mhauru Apr 15, 2025
be27636
Link varinfo by default in AD testing utilities; make test suite run …
penelopeysm Apr 16, 2025
e6453fe
Fix resetlogp!! and type stability for accumulators
mhauru Apr 16, 2025
c59400d
Fix type rigidity of LogProbs and NumProduce
mhauru Apr 16, 2025
47033ce
Fix uses of getlogp and other assorted issues
mhauru Apr 17, 2025
8b841c9
setaccs!! nicer interface and logdensity function fixes
mhauru Apr 17, 2025
3ee3989
Revert back to calling the macro @addlogprob!
mhauru Apr 22, 2025
13163f2
Remove a dead test
mhauru Apr 22, 2025
37dd6dd
Clarify a comment
mhauru Apr 22, 2025
d7013b6
Implement split/combine for PointwiseLogdensityAccumulator
mhauru Apr 22, 2025
40d4caa
Switch ThreadSafeVarInfo.accs_by_thread to be a tuple
mhauru Apr 22, 2025
ff5f2cb
Fix `condition` and `fix` in submodels (#892)
penelopeysm Apr 23, 2025
c68f1bb
Merge remote-tracking branch 'origin/main' into breaking
penelopeysm Apr 23, 2025
13da08a
Revert ThreadSafeVarInfo back to Vectors and fix some AD type casting…
mhauru Apr 24, 2025
d52feec
Merge remote-tracking branch 'origin/breaking' into mhauru/custom-acc…
mhauru Apr 24, 2025
221e797
Improve accumulator docs
mhauru Apr 24, 2025
1dbcb2c
Add test/accumulators.jl
mhauru Apr 24, 2025
e1b70e0
Docs fixes
mhauru Apr 24, 2025
3f195e5
Various small fixes
mhauru Apr 24, 2025
68b974a
Make DynamicTransformation not use accumulators other than LogPrior
mhauru Apr 24, 2025
2b405d9
Fix variable order and name of map_accumulator!!
mhauru Apr 24, 2025
00cd304
Typo fixing
mhauru Apr 24, 2025
6d1048d
Small improvement to ThreadSafeVarInfo
mhauru Apr 24, 2025
4fef20f
Fix demo_dot_assume_observe_submodel prefixing
mhauru Apr 24, 2025
905b874
Merge branch 'breaking' into mhauru/custom-accumulators
mhauru Apr 24, 2025
557954a
Typo fixing
mhauru Apr 24, 2025
6f702c9
Miscellaneous small fixes
mhauru Apr 25, 2025
f748775
HISTORY entry and more miscellanea
mhauru Apr 25, 2025
5f4a532
Add more tests for accumulators
mhauru Apr 25, 2025
31967fd
Improve accumulators docstrings
mhauru Apr 25, 2025
ad2f564
Fix a typo
mhauru Apr 25, 2025
10b4f2f
Expand HISTORY entry
mhauru Apr 25, 2025
d2b670d
Add accumulators to API docs
mhauru Apr 25, 2025
8241d12
Remove unexported functions from API docs
mhauru Apr 25, 2025
7b7a3e2
Add NamedTuple methods for get/set/acclogp
mhauru Apr 25, 2025
0b08237
Fix setlogp!! with single scalar to error
mhauru Apr 25, 2025
2a4b874
Export AbstractAccumulator, fix a docs typo
mhauru Apr 25, 2025
c1e90f7
Apply suggestions from code review
mhauru Apr 28, 2025
cb1c6c6
Rename LogPrior -> LogPriorAccumulator, and Likelihood and NumProduce
mhauru Apr 28, 2025
00ef0cf
Type bound log prob accumulators with T<:Real
mhauru Apr 28, 2025
14f4788
Add @addlogprior! and @addloglikelihood!
mhauru Apr 28, 2025
c4ee4ec
Apply suggestions from code review
mhauru Apr 30, 2025
7ad9450
Move default accumulators to default_accumulators.jl
mhauru Apr 30, 2025
c5e2a6b
Fix some tests
mhauru Apr 30, 2025
fb09acc
Introduce default_accumulators()
mhauru Apr 30, 2025
e324c9b
Go back to only having @addlogprob!
mhauru Apr 30, 2025
6b7b9f8
Fix tilde_observe!! prefixing
mhauru Apr 30, 2025
048178b
Fix default_accumulators internal type
mhauru Apr 30, 2025
6437801
Make unflatten more type stable, and add a test for it
mhauru May 1, 2025
bf95169
Always print all benchmark results
mhauru May 1, 2025
efc7c53
Move NumProduce VI functions to abstract_varinfo.jl
mhauru May 1, 2025
18 changes: 18 additions & 0 deletions HISTORY.md
@@ -1,5 +1,23 @@
# DynamicPPL Changelog

## 0.37.0

**Breaking changes**

### Accumulators

This release overhauls how `VarInfo` objects track quantities such as the log joint probability. The new approach uses what we call accumulators: objects carried by the `VarInfo` that may change their state at each `tilde_assume!!` and `tilde_observe!!` call, based on the value of the variable in question. They replace both the variables that were previously hard-coded in the `VarInfo` object (`logp` and `num_produce`) and some contexts. This brings with it a number of breaking changes:
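
To make the idea concrete, here is a minimal, self-contained Julia sketch of the accumulator pattern. It does not use DynamicPPL: the names `ToyLogPrior`, `ToyLogLikelihood`, `on_assume`, `on_observe`, `normal_logpdf`, and `run_toy_model` are hypothetical and exist only for illustration. Each accumulator is an immutable value that returns an updated copy when it is shown a variable, and the model run threads a tuple of accumulators through every statement.

```julia
# Toy illustration of the accumulator pattern; none of these names are DynamicPPL API.

# Log density of Normal(mu, sigma) at x, written out to avoid any dependencies.
normal_logpdf(x, mu, sigma) = -0.5 * log(2 * pi) - log(sigma) - 0.5 * ((x - mu) / sigma)^2

# Two hypothetical accumulators. Each is immutable and carries one running total.
struct ToyLogPrior
    logp::Float64
end
struct ToyLogLikelihood
    logp::Float64
end

# Hooks called once per statement. An accumulator ignores events it does not care about.
on_assume(acc::ToyLogPrior, logp) = ToyLogPrior(acc.logp + logp)
on_assume(acc::ToyLogLikelihood, logp) = acc
on_observe(acc::ToyLogPrior, logp) = acc
on_observe(acc::ToyLogLikelihood, logp) = ToyLogLikelihood(acc.logp + logp)

# A "model" with one latent variable m ~ Normal(0, 1) and observations y ~ Normal(m, 1).
function run_toy_model(accs::Tuple, m, ys)
    lp = normal_logpdf(m, 0.0, 1.0)
    accs = map(acc -> on_assume(acc, lp), accs)
    for y in ys
        ll = normal_logpdf(y, m, 1.0)
        accs = map(acc -> on_observe(acc, ll), accs)
    end
    return accs
end

# Track both quantities, or only the prior by carrying a single accumulator.
both = run_toy_model((ToyLogPrior(0.0), ToyLogLikelihood(0.0)), 0.5, [0.1, 0.9])
prior_only = run_toy_model((ToyLogPrior(0.0),), 0.5, [0.1, 0.9])
logjoint = both[1].logp + both[2].logp
```

The point of the design is visible in the last lines: which quantities are computed is decided by which accumulators are carried, not by a separate context type.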

- `PriorContext` and `LikelihoodContext` no longer exist. By default, a `VarInfo` tracks both the log prior and the log likelihood separately, and they can be accessed with `getlogprior` and `getloglikelihood`. If you want to execute a model while only accumulating one of the two (to save clock cycles), you can do so by creating a `VarInfo` that only has one accumulator in it, e.g. `varinfo = setaccs!!(varinfo, (LogPriorAccumulator(),))`.
- `MiniBatchContext` no longer exists. It can be replaced by creating and using a custom accumulator that replaces the default `LogLikelihoodAccumulator`. We may introduce such an accumulator in DynamicPPL in the future, but for now you'll need to write it yourself.
- `tilde_observe` and `observe` have been removed. `tilde_observe!!` still exists, and any context that needs to change how observations are handled should modify the behaviour of `tilde_observe!!` instead. We may further rework the call stack under `tilde_observe!!` in the near future.
- `tilde_assume` no longer returns the log density of the current assumption as its second return value. We may further rework the `tilde_assume!!` call stack as well.

Review comment (Member): Any reason why we can't remove `tilde_assume` like `tilde_observe` and `observe`?

Reply (Member, Author): It's on my list to do, but I didn't want to put it in the same PR since I didn't have to. (I had to do `tilde_observe` because of complications with `PointwiseLogdensityAccumulator`.) Or, more precisely, what's on my list is to revisit the whole call stack for both `tilde_assume!!` and `tilde_observe!!` and see what the best way to do things is.

- For literal observation statements like `0.0 ~ Normal(blahblah)` we used to call `tilde_observe!!` without the `vn` argument. This method no longer exists. Instead, we call `tilde_observe!!` with `vn` set to `nothing`.
- `set/reset/increment_num_produce!` have become `set/reset/increment_num_produce!!` (note the second exclamation mark). They are no longer guaranteed to modify the `VarInfo` in place, and one should always use the return value.
- `@addlogprob!` now _always_ adds to the log likelihood. Previously it added to the log probability that the execution context specified, e.g. the log prior when using `PriorContext`.
- `getlogp` now returns a `NamedTuple` with keys `logprior` and `loglikelihood`. If you want the log joint probability, which is what `getlogp` used to return, use `getlogjoint`.
- Correspondingly `setlogp!!` and `acclogp!!` should now be called with a `NamedTuple` with keys `logprior` and `loglikelihood`. The `acclogp!!` method with a single scalar value has been deprecated and falls back on `accloglikelihood!!`, and the single scalar version of `setlogp!!` has been removed. Corresponding setter/accumulator functions exist for the log prior as well.
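
Two of the changes listed above, the `!!` suffix and the `NamedTuple` returned by `getlogp`, follow patterns that can be sketched without DynamicPPL. In the toy code below, `ToyState`, `toy_increment!!`, `toy_getlogp`, and `toy_getlogjoint` are hypothetical stand-ins rather than the library's functions; the sketch only illustrates why the return value of a `!!` function must be used and how the joint relates to its two components.

```julia
# Toy illustration only; these are stand-ins, not DynamicPPL functions.

# An immutable state cannot be updated in place, so a `!!` function must return a new one.
struct ToyState
    num_produce::Int
    logprior::Float64
    loglikelihood::Float64
end

# `!!` convention: the function may or may not mutate its argument,
# so callers must always continue with the returned value.
toy_increment!!(s::ToyState) = ToyState(s.num_produce + 1, s.logprior, s.loglikelihood)

# Shape of the new return value: the two components are reported separately...
toy_getlogp(s::ToyState) = (logprior=s.logprior, loglikelihood=s.loglikelihood)
# ...and the joint is their sum.
toy_getlogjoint(s::ToyState) = s.logprior + s.loglikelihood

s0 = ToyState(0, -1.5, -2.5)
s1 = toy_increment!!(s0)   # correct: rebind to the return value
toy_increment!!(s0)        # wrong usage: the result is discarded and s0 is unchanged
logp = toy_getlogp(s1)
```

The same rebinding habit applies to the other `!!` functions named in this changelog, such as `setaccs!!`, `setlogp!!`, and `acclogp!!`.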

## 0.36.0

**Breaking changes**
2 changes: 2 additions & 0 deletions Project.toml
@@ -21,6 +21,7 @@ LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
LogDensityProblems = "6fdf6af0-433a-55f7-b3ed-c6c6e0b8df7c"
MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
OrderedCollections = "bac558e1-5e72-5ebc-8fee-abe8a469f55d"
+Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Requires = "ae029012-a4dd-5104-9daa-d747884805df"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
@@ -68,6 +69,7 @@ MCMCChains = "6"
MacroTools = "0.5.6"
Mooncake = "0.4.95"
OrderedCollections = "1"
+Printf = "1.10"
Random = "1.6"
Requires = "1"
Statistics = "1"
40 changes: 29 additions & 11 deletions docs/src/api.md
@@ -160,10 +160,12 @@ returned(::Model)

## Utilities

-It is possible to manually increase (or decrease) the accumulated log density from within a model function.
+It is possible to manually increase (or decrease) the accumulated log likelihood or prior from within a model function.

```@docs
@addlogprob!
+@addloglikelihood!
+@addlogprior!
```

Return values of the model function for a collection of samples can be obtained with [`returned(model, chain)`](@ref).
@@ -328,9 +330,9 @@ The following functions were used for sequential Monte Carlo methods.

```@docs
get_num_produce
-set_num_produce!
-increment_num_produce!
-reset_num_produce!
+set_num_produce!!
+increment_num_produce!!
+reset_num_produce!!
setorder!
set_retained_vns_del!
```
@@ -345,6 +347,22 @@ Base.empty!
SimpleVarInfo
```

+### Accumulators
+
+The subtypes of [`AbstractVarInfo`](@ref) store the cumulative log prior and log likelihood, and sometimes other variables that change during model execution, in what are called accumulators.
+
+```@docs
+AbstractAccumulator
+```
+
+DynamicPPL provides the following default accumulators.
+
+```@docs
+LogPriorAccumulator
+LogLikelihoodAccumulator
+NumProduceAccumulator
+```
+
### Common API

#### Accumulation of log-probabilities
@@ -353,6 +371,13 @@ SimpleVarInfo
getlogp
setlogp!!
acclogp!!
+getlogjoint
+getlogprior
+setlogprior!!
+acclogprior!!
+getloglikelihood
+setloglikelihood!!
+accloglikelihood!!
resetlogp!!
```

@@ -427,9 +452,6 @@ Contexts are subtypes of `AbstractPPL.AbstractContext`.
```@docs
SamplingContext
DefaultContext
-LikelihoodContext
-PriorContext
-MiniBatchContext
PrefixContext
ConditionContext
```
@@ -476,7 +498,3 @@ DynamicPPL.Experimental.is_suitable_varinfo
```@docs
tilde_assume
```
-
-```@docs
-tilde_observe
-```
12 changes: 6 additions & 6 deletions ext/DynamicPPLMCMCChainsExt.jl
@@ -48,18 +48,18 @@
Sample from the posterior predictive distribution by executing `model` with parameters fixed to each sample
in `chain`, and return the resulting `Chains`.

-The `model` passed to `predict` is often different from the one used to generate `chain`.
-Typically, the model from which `chain` originated treats certain variables as observed (i.e.,
-data points), while the model you pass to `predict` may mark these same variables as missing
-or unobserved. Calling `predict` then leverages the previously inferred parameter values to
+The `model` passed to `predict` is often different from the one used to generate `chain`.
+Typically, the model from which `chain` originated treats certain variables as observed (i.e.,
+data points), while the model you pass to `predict` may mark these same variables as missing
+or unobserved. Calling `predict` then leverages the previously inferred parameter values to
simulate what new, unobserved data might look like, given your posterior beliefs.

For each parameter configuration in `chain`:
1. All random variables present in `chain` are fixed to their sampled values.
2. Any variables not included in `chain` are sampled from their prior distributions.

If `include_all` is `false`, the returned `Chains` will contain only those variables that were not fixed by
-the samples in `chain`. This is useful when you want to sample only new variables from the posterior
+the samples in `chain`. This is useful when you want to sample only new variables from the posterior
predictive distribution.

# Examples
@@ -124,7 +124,7 @@
map(DynamicPPL.varname_and_value_leaves, keys(vals), values(vals)),
)

-return (varname_and_values=varname_vals, logp=DynamicPPL.getlogp(varinfo))
+return (varname_and_values=varname_vals, logp=DynamicPPL.getlogjoint(varinfo))
end

chain_result = reduce(
29 changes: 21 additions & 8 deletions src/DynamicPPL.jl
@@ -6,6 +6,7 @@ using Bijectors
using Compat
using Distributions
using OrderedCollections: OrderedCollections, OrderedDict
+using Printf: Printf

using AbstractMCMC: AbstractMCMC
using ADTypes: ADTypes
@@ -46,17 +47,28 @@ import Base:
export AbstractVarInfo,
VarInfo,
SimpleVarInfo,
+AbstractAccumulator,
+LogLikelihoodAccumulator,
+LogPriorAccumulator,
+NumProduceAccumulator,
push!!,
empty!!,
subset,
getlogp,
+getlogjoint,
+getlogprior,
+getloglikelihood,
setlogp!!,
+setlogprior!!,
+setloglikelihood!!,
acclogp!!,
+acclogprior!!,
+accloglikelihood!!,
resetlogp!!,
get_num_produce,
-set_num_produce!,
-reset_num_produce!,
-increment_num_produce!,
+set_num_produce!!,
+reset_num_produce!!,
+increment_num_produce!!,
set_retained_vns_del!,
is_flagged,
set_flag!,
@@ -92,15 +104,10 @@ export AbstractVarInfo,
# Contexts
SamplingContext,
DefaultContext,
-LikelihoodContext,
-PriorContext,
-MiniBatchContext,
PrefixContext,
ConditionContext,
assume,
-observe,
tilde_assume,
-tilde_observe,
# Pseudo distributions
NamedDist,
NoDist,
@@ -120,6 +127,8 @@ export AbstractVarInfo,
to_submodel,
# Convenience macros
@addlogprob!,
+@addlogprior!,
+@addloglikelihood!,
@submodel,
value_iterator_from_chain,
check_model,
@@ -146,6 +155,9 @@ macro prob_str(str)
))
end

+# TODO(mhauru) We should write down the list of methods that any subtype of AbstractVarInfo
+# has to implement. Not sure what the full list is for parameter values, but for
+# accumulators we only need `getaccs` and `setaccs!!`.
"""
AbstractVarInfo

@@ -166,6 +178,7 @@ include("varname.jl")
include("distribution_wrappers.jl")
include("contexts.jl")
include("varnamedvector.jl")
+include("accumulators.jl")
include("abstract_varinfo.jl")
include("threadsafe.jl")
include("varinfo.jl")