Skip to content

Commit 07cc442

Browse files
torfjeldeyebai
authored and
Alexey Stukalov
committed
Linearization/flattening of SimpleVarInfo (TuringLang#417)
This PR introduces a couple of things, though these are related: 1. `unflatten(original[, spl], x)`: converts from a certain input `x`, usually a `Vector`, into an instance similar to `original`. - Effectively the same as the current constructor `VarInfo(varinfo_old, spl, x)`. - I looked into using ParameterHandling.jl for this but decided against it for a couple of reasons: - Seems overkill. - `unflatten`-equivalent is constructed as a closure, which means that we need to keep track of this returned method rather than just using a "template" `AbstractVarInfo` + construction of unflattening requires construction of the flatten representation + the way one specifies the types is a bit too opinionated (which causes some issues with certain AD-frameworks) + closures can have less desirable performance characteristics. - The current Turing.jl-codebase is easily adapted to this `unflatten` since it's really just a matter of replacing calls `VarInfo(varinfo_old, spl, x)` with `unflatten(varinfo_old, spl, x)`. A ParameterHandling.jl approach will require more work. 2. `link!!` and `invlink!!`, BangBang-versions of `link!` and `invlink!`, respectively, with some differences: - These take additional arguments which should always be sufficient to determine the transformation. These are: - `model` - `sampler` - `t::AbstractTransformation`. This sets us up for allowing alternative transformations to be used. As of right now, this only has an affect when calling `link!!` and `invlink!!`, _not_ when used inside of the tilde-pipeline. - Also adds the logabsdet-jacobian term to the `logp`, so that `getlogp(vi) ≠ getlogp(link!!(vi))` holds. This allows us to compute, say, `logjoint` by _first_ linking `vi` in a single pass, and then compute `logjoint(settrans!(vi, NoTransformation()), θ_constrained)`. Such a pattern, in particular if the transformation has been specified by the user themselves, will usually have much better performance than the `logpdf_with_trans(..., true)` within the tilde-callstack. 3. `make_default_varinfo(rng, model, sampler)` which allows one to overload on a, say, per-model or model-sampler-combination basis to specify which implementation of `AbstractVarInfo` to use. - Not entirely happy with this approach 😕 EDIT: This should be merged _after_ TuringLang#420 Co-authored-by: Hong Ge <[email protected]>
1 parent a3f12dc commit 07cc442

19 files changed

+1383
-463
lines changed

Project.toml

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "DynamicPPL"
22
uuid = "366bfd00-2699-11ea-058f-f148b4cae6d8"
3-
version = "0.20.2"
3+
version = "0.21.0"
44

55
[deps]
66
AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"

docs/make.jl

+3-1
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,9 @@ using DynamicPPL
33
using DynamicPPL: AbstractPPL
44

55
# Doctest setup
6-
DocMeta.setdocmeta!(DynamicPPL, :DocTestSetup, :(using DynamicPPL); recursive=true)
6+
DocMeta.setdocmeta!(
7+
DynamicPPL, :DocTestSetup, :(using DynamicPPL, Distributions); recursive=true
8+
)
79

810
makedocs(;
911
sitename="DynamicPPL",

docs/src/api.md

+33-2
Original file line numberDiff line numberDiff line change
@@ -156,23 +156,56 @@ AbstractVarInfo
156156

157157
### Common API
158158

159+
#### Accumulation of log-probabilities
160+
159161
```@docs
160162
getlogp
161163
setlogp!!
162164
acclogp!!
163165
resetlogp!!
164166
```
165167

168+
#### Variables and their realizations
169+
166170
```@docs
171+
keys
167172
getindex
173+
DynamicPPL.getindex_raw
168174
push!!
169175
empty!!
176+
isempty
170177
```
171178

172179
```@docs
173180
values_as
174181
```
175182

183+
#### Transformations
184+
185+
```@docs
186+
DynamicPPL.AbstractTransformation
187+
DynamicPPL.NoTransformation
188+
DynamicPPL.DynamicTransformation
189+
DynamicPPL.StaticTransformation
190+
```
191+
192+
```@docs
193+
DynamicPPL.istrans
194+
DynamicPPL.settrans!!
195+
DynamicPPL.transformation
196+
DynamicPPL.link!!
197+
DynamicPPL.invlink!!
198+
DynamicPPL.default_transformation
199+
DynamicPPL.maybe_invlink_before_eval!!
200+
```
201+
202+
#### Utils
203+
204+
```@docs
205+
DynamicPPL.unflatten
206+
DynamicPPL.tonamedtuple
207+
```
208+
176209
#### `SimpleVarInfo`
177210

178211
```@docs
@@ -191,10 +224,8 @@ TypedVarInfo
191224
One main characteristic of [`VarInfo`](@ref) is that samples are stored in a linearized form.
192225

193226
```@docs
194-
tonamedtuple
195227
link!
196228
invlink!
197-
istrans
198229
```
199230

200231
```@docs

src/DynamicPPL.jl

+12-3
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,9 @@ export AbstractVarInfo,
5959
setorder!,
6060
istrans,
6161
link!,
62+
link!!,
6263
invlink!,
64+
invlink!!,
6365
tonamedtuple,
6466
values_as,
6567
# VarName (reexport from AbstractPPL)
@@ -126,27 +128,33 @@ export loglikelihood
126128
# Used here and overloaded in Turing
127129
function getspace end
128130

129-
# Necessary forward declarations
130131
"""
131132
AbstractVarInfo
132133
133134
Abstract supertype for data structures that capture random variables when executing a
134135
probabilistic model and accumulate log densities such as the log likelihood or the
135136
log joint probability of the model.
136137
137-
See also: [`VarInfo`](@ref)
138+
See also: [`VarInfo`](@ref), [`SimpleVarInfo`](@ref).
138139
"""
139140
abstract type AbstractVarInfo <: AbstractModelTrace end
140141

142+
const LEGACY_WARNING = """
143+
!!! warning
144+
This method is considered legacy, and is likely to be deprecated in the future.
145+
"""
146+
147+
# Necessary forward declarations
141148
include("utils.jl")
142149
include("selector.jl")
143150
include("model.jl")
144151
include("sampler.jl")
145152
include("varname.jl")
146153
include("distribution_wrappers.jl")
147154
include("contexts.jl")
148-
include("varinfo.jl")
155+
include("abstract_varinfo.jl")
149156
include("threadsafe.jl")
157+
include("varinfo.jl")
150158
include("simple_varinfo.jl")
151159
include("context_implementations.jl")
152160
include("compiler.jl")
@@ -155,5 +163,6 @@ include("compat/ad.jl")
155163
include("loglikelihoods.jl")
156164
include("submodel_macro.jl")
157165
include("test_utils.jl")
166+
include("transforming.jl")
158167

159168
end # module

0 commit comments

Comments
 (0)