Skip to content

Commit 0947bd7

Browse files
torfjeldeyebai
andcommitted
Linearization/flattening of SimpleVarInfo (#417)
This PR introduces a couple of things, though these are related: 1. `unflatten(original[, spl], x)`: converts from a certain input `x`, usually a `Vector`, into an instance similar to `original`. - Effectively the same as the current constructor `VarInfo(varinfo_old, spl, x)`. - I looked into using ParameterHandling.jl for this but decided against it for a couple of reasons: - Seems overkill. - `unflatten`-equivalent is constructed as a closure, which means that we need to keep track of this returned method rather than just using a "template" `AbstractVarInfo` + construction of unflattening requires construction of the flatten representation + the way one specifies the types is a bit too opinionated (which causes some issues with certain AD-frameworks) + closures can have less desirable performance characteristics. - The current Turing.jl-codebase is easily adapted to this `unflatten` since it's really just a matter of replacing calls `VarInfo(varinfo_old, spl, x)` with `unflatten(varinfo_old, spl, x)`. A ParameterHandling.jl approach will require more work. 2. `link!!` and `invlink!!`, BangBang-versions of `link!` and `invlink!`, respectively, with some differences: - These take additional arguments which should always be sufficient to determine the transformation. These are: - `model` - `sampler` - `t::AbstractTransformation`. This sets us up for allowing alternative transformations to be used. As of right now, this only has an affect when calling `link!!` and `invlink!!`, _not_ when used inside of the tilde-pipeline. - Also adds the logabsdet-jacobian term to the `logp`, so that `getlogp(vi) ≠ getlogp(link!!(vi))` holds. This allows us to compute, say, `logjoint` by _first_ linking `vi` in a single pass, and then compute `logjoint(settrans!(vi, NoTransformation()), θ_constrained)`. Such a pattern, in particular if the transformation has been specified by the user themselves, will usually have much better performance than the `logpdf_with_trans(..., true)` within the tilde-callstack. 3. `make_default_varinfo(rng, model, sampler)` which allows one to overload on a, say, per-model or model-sampler-combination basis to specify which implementation of `AbstractVarInfo` to use. - Not entirely happy with this approach 😕 EDIT: This should be merged _after_ #420 Co-authored-by: Hong Ge <[email protected]>
1 parent 0457785 commit 0947bd7

19 files changed

+1383
-463
lines changed

Project.toml

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "DynamicPPL"
22
uuid = "366bfd00-2699-11ea-058f-f148b4cae6d8"
3-
version = "0.20.2"
3+
version = "0.21.0"
44

55
[deps]
66
AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"

docs/make.jl

+3-1
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,9 @@ using DynamicPPL
33
using DynamicPPL: AbstractPPL
44

55
# Doctest setup
6-
DocMeta.setdocmeta!(DynamicPPL, :DocTestSetup, :(using DynamicPPL); recursive=true)
6+
DocMeta.setdocmeta!(
7+
DynamicPPL, :DocTestSetup, :(using DynamicPPL, Distributions); recursive=true
8+
)
79

810
makedocs(;
911
sitename="DynamicPPL",

docs/src/api.md

+33-2
Original file line numberDiff line numberDiff line change
@@ -154,23 +154,56 @@ AbstractVarInfo
154154

155155
### Common API
156156

157+
#### Accumulation of log-probabilities
158+
157159
```@docs
158160
getlogp
159161
setlogp!!
160162
acclogp!!
161163
resetlogp!!
162164
```
163165

166+
#### Variables and their realizations
167+
164168
```@docs
169+
keys
165170
getindex
171+
DynamicPPL.getindex_raw
166172
push!!
167173
empty!!
174+
isempty
168175
```
169176

170177
```@docs
171178
values_as
172179
```
173180

181+
#### Transformations
182+
183+
```@docs
184+
DynamicPPL.AbstractTransformation
185+
DynamicPPL.NoTransformation
186+
DynamicPPL.DynamicTransformation
187+
DynamicPPL.StaticTransformation
188+
```
189+
190+
```@docs
191+
DynamicPPL.istrans
192+
DynamicPPL.settrans!!
193+
DynamicPPL.transformation
194+
DynamicPPL.link!!
195+
DynamicPPL.invlink!!
196+
DynamicPPL.default_transformation
197+
DynamicPPL.maybe_invlink_before_eval!!
198+
```
199+
200+
#### Utils
201+
202+
```@docs
203+
DynamicPPL.unflatten
204+
DynamicPPL.tonamedtuple
205+
```
206+
174207
#### `SimpleVarInfo`
175208

176209
```@docs
@@ -189,10 +222,8 @@ TypedVarInfo
189222
One main characteristic of [`VarInfo`](@ref) is that samples are stored in a linearized form.
190223

191224
```@docs
192-
tonamedtuple
193225
link!
194226
invlink!
195-
istrans
196227
```
197228

198229
```@docs

src/DynamicPPL.jl

+12-3
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,9 @@ export AbstractVarInfo,
5959
setorder!,
6060
istrans,
6161
link!,
62+
link!!,
6263
invlink!,
64+
invlink!!,
6365
tonamedtuple,
6466
values_as,
6567
# VarName (reexport from AbstractPPL)
@@ -125,27 +127,33 @@ export loglikelihood
125127
# Used here and overloaded in Turing
126128
function getspace end
127129

128-
# Necessary forward declarations
129130
"""
130131
AbstractVarInfo
131132
132133
Abstract supertype for data structures that capture random variables when executing a
133134
probabilistic model and accumulate log densities such as the log likelihood or the
134135
log joint probability of the model.
135136
136-
See also: [`VarInfo`](@ref)
137+
See also: [`VarInfo`](@ref), [`SimpleVarInfo`](@ref).
137138
"""
138139
abstract type AbstractVarInfo <: AbstractModelTrace end
139140

141+
const LEGACY_WARNING = """
142+
!!! warning
143+
This method is considered legacy, and is likely to be deprecated in the future.
144+
"""
145+
146+
# Necessary forward declarations
140147
include("utils.jl")
141148
include("selector.jl")
142149
include("model.jl")
143150
include("sampler.jl")
144151
include("varname.jl")
145152
include("distribution_wrappers.jl")
146153
include("contexts.jl")
147-
include("varinfo.jl")
154+
include("abstract_varinfo.jl")
148155
include("threadsafe.jl")
156+
include("varinfo.jl")
149157
include("simple_varinfo.jl")
150158
include("context_implementations.jl")
151159
include("compiler.jl")
@@ -154,5 +162,6 @@ include("compat/ad.jl")
154162
include("loglikelihoods.jl")
155163
include("submodel_macro.jl")
156164
include("test_utils.jl")
165+
include("transforming.jl")
157166

158167
end # module

0 commit comments

Comments
 (0)