
Commit 7b01d25

torfjelde, github-actions[bot], yebai, and devmotion authored
Proper support for distributions with embedded support (#462)
* compat with new Bijectors.jl
* bump compat bounds for Bijectors and make it a breaking change
* remove mentioning of Exp and Identity in test_utils.jl
* added mistakenly commented out tests
* fixed test_utils
* bump bijectors version
* added no-op impls for reconstruct
* added a bunch of convenience methods for working with Metadata instead of VarInfo
* added usage of _inner_transform! in link, in addition to additional methods for linking and invlinking
* updated getall to not assume we want all the values in metadata
* added FIXME comment
* fixed typo in comment
* Apply suggestions from code review
  Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
* slight simplification of the linking stuff
* formatting
* lower bound test compat entry for Tracker
* move link-related functions to abstract_varinfo.jl and renamed methods to be more descriptive
* fixed invlink!! for VarInfo
* fixed link and invlink tests
* added specialized mapreduce for (named)tuples to improve type-inference
* Apply suggestions from code review
  Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
* added missing docstring
* added minor TODO comment for the future
* added `link_transform` and `invlink_transform`, basically equivalent to `bijector` but allows us to separate the choices made in DPPL from those in Bijectors
* Apply suggestions from code review
  Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
* Update src/utils.jl
* added some docstrings
* renamed link_and_reconstruct to the more accurate reconstruct_and_link
* removed unnecessary definition of inlink_transform
* fixed bug in newmetadata
* removed mapreduce_tuple in favor of reduce and map
* Update src/utils.jl
  Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
* introduce _logpdf_with_trans as a placeholder while we migrate away from the usage of this function and into invlink_and_pdf
* reconstruct now takes into account the transformation to be used
* replaced more references to bijector with link_transform
* added docstring for invlink_with_logpdf
* fixed bug in assume introduced hacky getval for SimpleVarInfo
* added tests for linking
* Apply suggestions from code review
  Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
* rename maybe_link_and_reconstruct to maybe_reconstruct_and_link
* added reconstruct to the API docs
* Update docs/src/api.md
  Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
* removed unnecessary comment
* removed _logpdf_with_trans in favour of just using Bijectors.jl's for now
* Apply suggestions from code review
  Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
* added warning regarding overloading to link_transform and invlink_transform
* added missing getval for ThreadSafeVarInfo
* added a minor additional test to linking
* Apply suggestions from code review
  Co-authored-by: David Widmann <[email protected]>
* reverted changes from previous commit
* fixed usage of deprecated link
* Update test/linking.jl
  Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
* Update src/DynamicPPL.jl
* fixed tests
* added copy to tonamedtuple to avoid mutating chain samples
* improved testing for setval! and generated_quantities
* bumped the version in turing tests
* Apply suggestions from code review

---------

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
Co-authored-by: Hong Ge <[email protected]>
Co-authored-by: David Widmann <[email protected]>
1 parent 5a729e2 commit 7b01d25

13 files changed, with 331 additions and 175 deletions

docs/src/api.md

+2-1
Original file line number | Diff line number | Diff line change
@@ -206,7 +206,8 @@ DynamicPPL.link!!
206206
DynamicPPL.invlink!!
207207
DynamicPPL.default_transformation
208208
DynamicPPL.maybe_invlink_before_eval!!
209-
```
209+
DynamicPPL.reconstruct
210+
```
210211

211212
#### Utils
212213

src/DynamicPPL.jl

-1
@@ -43,7 +43,6 @@ export AbstractVarInfo,
4343
push!!,
4444
empty!!,
4545
getlogp,
46-
resetlogp!,
4746
setlogp!!,
4847
acclogp!!,
4948
resetlogp!!,

src/abstract_varinfo.jl

+93
@@ -553,6 +553,99 @@ variables `x` would return
553553
"""
554554
function tonamedtuple end
555555

556+
# TODO: Clean up all this linking stuff once and for all!
557+
"""
558+
with_logabsdet_jacobian_and_reconstruct([f, ]dist, x)
559+
560+
Like `Bijectors.with_logabsdet_jacobian(f, x)`, but also ensures the resulting
561+
value is reconstructed to the correct type and shape according to `dist`.
562+
"""
563+
function with_logabsdet_jacobian_and_reconstruct(f, dist, x)
564+
x_recon = reconstruct(f, dist, x)
565+
return with_logabsdet_jacobian(f, x_recon)
566+
end
567+
568+
# TODO: Once `(inv)link` isn't used heavily in `getindex(vi, vn)`, we can
569+
# just use `first ∘ with_logabsdet_jacobian` to reduce the maintenance burden.
570+
# NOTE: `reconstruct` is a no-op if `val` is already of the correct shape.
571+
"""
572+
reconstruct_and_link(dist, val)
573+
reconstruct_and_link(vi::AbstractVarInfo, vn::VarName, dist, val)
574+
575+
Return linked `val` but reconstruct before linking, if necessary.
576+
577+
Note that unlike [`invlink_and_reconstruct`](@ref), this does not necessarily
578+
return a reconstructed value, i.e. a value of the same type and shape as expected
579+
by `dist`.
580+
581+
See also: [`invlink_and_reconstruct`](@ref), [`reconstruct`](@ref).
582+
"""
583+
reconstruct_and_link(f, dist, val) = f(reconstruct(f, dist, val))
584+
reconstruct_and_link(dist, val) = reconstruct_and_link(link_transform(dist), dist, val)
585+
function reconstruct_and_link(::AbstractVarInfo, ::VarName, dist, val)
586+
return reconstruct_and_link(dist, val)
587+
end
588+
589+
"""
590+
invlink_and_reconstruct(dist, val)
591+
invlink_and_reconstruct(vi::AbstractVarInfo, vn::VarName, dist, val)
592+
593+
Return invlinked and reconstructed `val`.
594+
595+
See also: [`reconstruct_and_link`](@ref), [`reconstruct`](@ref).
596+
"""
597+
invlink_and_reconstruct(f, dist, val) = f(reconstruct(f, dist, val))
598+
function invlink_and_reconstruct(dist, val)
599+
return invlink_and_reconstruct(invlink_transform(dist), dist, val)
600+
end
601+
function invlink_and_reconstruct(::AbstractVarInfo, ::VarName, dist, val)
602+
return invlink_and_reconstruct(dist, val)
603+
end
604+
605+
"""
606+
maybe_reconstruct_and_link(vi::AbstractVarInfo, vn::VarName, dist, val)
607+
608+
Return reconstructed `val`, possibly linked if `istrans(vi, vn)` is `true`.
609+
"""
610+
function maybe_reconstruct_and_link(vi::AbstractVarInfo, vn::VarName, dist, val)
611+
return if istrans(vi, vn)
612+
reconstruct_and_link(vi, vn, dist, val)
613+
else
614+
reconstruct(dist, val)
615+
end
616+
end
617+
618+
"""
619+
maybe_invlink_and_reconstruct(vi::AbstractVarInfo, vn::VarName, dist, val)
620+
621+
Return reconstructed `val`, possibly invlinked if `istrans(vi, vn)` is `true`.
622+
"""
623+
function maybe_invlink_and_reconstruct(vi::AbstractVarInfo, vn::VarName, dist, val)
624+
return if istrans(vi, vn)
625+
invlink_and_reconstruct(vi, vn, dist, val)
626+
else
627+
reconstruct(dist, val)
628+
end
629+
end
630+
631+
"""
632+
invlink_with_logpdf(vi::AbstractVarInfo, vn::VarName, dist[, x])
633+
634+
Invlink `x` and compute the logpdf under `dist` including correction from
635+
the invlink-transformation.
636+
637+
If `x` is not provided, `getval(vi, vn)` will be used.
638+
"""
639+
function invlink_with_logpdf(vi::AbstractVarInfo, vn::VarName, dist)
640+
return invlink_with_logpdf(vi, vn, dist, getval(vi, vn))
641+
end
642+
function invlink_with_logpdf(vi::AbstractVarInfo, vn::VarName, dist, y)
643+
# NOTE: Will this cause type-instabilities or will union-splitting save us?
644+
f = istrans(vi, vn) ? invlink_transform(dist) : identity
645+
x, logjac = with_logabsdet_jacobian_and_reconstruct(f, dist, y)
646+
return x, logpdf(dist, x) + logjac
647+
end
648+
556649
# Legacy code that is currently overloaded for the sake of simplicity.
557650
# TODO: Remove when possible.
558651
increment_num_produce!(::AbstractVarInfo) = nothing
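
The change of variables behind `invlink_with_logpdf` can be sketched in plain Julia, without DynamicPPL or Bijectors. Every name prefixed with `toy_` is hypothetical and exists only for this illustration; the sketch assumes a variable with positive support whose link is `log`, so the inverse link is `exp` with log-Jacobian `y`.

```julia
# Toy stand-ins (hypothetical, not part of DynamicPPL or Bijectors).
# Log-density of an Exponential(1) distribution, supported on x > 0.
toy_logpdf(x) = x > 0 ? -x : -Inf

# Unconstrained-to-constrained map: returns exp(y) and log|d/dy exp(y)| = y.
toy_invlink_with_logabsdetjac(y) = (exp(y), y)

# Mirrors the shape of `invlink_with_logpdf`: map the stored value back to the
# support, then add the log-Jacobian correction only if it was stored linked.
function toy_invlink_with_logpdf(y, is_linked::Bool)
    x, logjac = is_linked ? toy_invlink_with_logabsdetjac(y) : (y, zero(y))
    return x, toy_logpdf(x) + logjac
end

x, logp = toy_invlink_with_logpdf(0.5, true)
x_raw, logp_raw = toy_invlink_with_logpdf(0.5, false)
```

With `is_linked = false` the transform is the identity and the correction is zero, which is how the `identity` branch in the diff behaves.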

src/context_implementations.jl

+21-9
@@ -194,8 +194,8 @@ end
194194

195195
# fallback without sampler
196196
function assume(dist::Distribution, vn::VarName, vi)
197-
r = vi[vn, dist]
198-
return r, Bijectors.logpdf_with_trans(dist, r, istrans(vi, vn)), vi
197+
r, logp = invlink_with_logpdf(vi, vn, dist)
198+
return r, logp, vi
199199
end
200200

201201
# SampleFromPrior and SampleFromUniform
@@ -211,7 +211,9 @@ function assume(
211211
if sampler isa SampleFromUniform || is_flagged(vi, vn, "del")
212212
unset_flag!(vi, vn, "del")
213213
r = init(rng, dist, sampler)
214-
BangBang.setindex!!(vi, vectorize(dist, maybe_link(vi, vn, dist, r)), vn)
214+
BangBang.setindex!!(
215+
vi, vectorize(dist, maybe_reconstruct_and_link(vi, vn, dist, r)), vn
216+
)
215217
setorder!(vi, vn, get_num_produce(vi))
216218
else
217219
# Otherwise we just extract it.
@@ -220,15 +222,17 @@ function assume(
220222
else
221223
r = init(rng, dist, sampler)
222224
if istrans(vi)
223-
push!!(vi, vn, link(dist, r), dist, sampler)
225+
push!!(vi, vn, reconstruct_and_link(dist, r), dist, sampler)
224226
# By default `push!!` sets the transformed flag to `false`.
225227
settrans!!(vi, true, vn)
226228
else
227229
push!!(vi, vn, r, dist, sampler)
228230
end
229231
end
230232

231-
return r, Bijectors.logpdf_with_trans(dist, r, istrans(vi, vn)), vi
233+
# HACK: The above code might involve an `invlink` somewhere, etc. so we need to correct.
234+
logjac = logabsdetjac(istrans(vi, vn) ? link_transform(dist) : identity, r)
235+
return r, logpdf(dist, r) - logjac, vi
232236
end
233237

234238
# default fallback (used e.g. by `SampleFromPrior` and `SampleUniform`)
@@ -470,7 +474,11 @@ function get_and_set_val!(
470474
r = init(rng, dist, spl, n)
471475
for i in 1:n
472476
vn = vns[i]
473-
setindex!!(vi, vectorize(dist, maybe_link(vi, vn, dist, r[:, i])), vn)
477+
setindex!!(
478+
vi,
479+
vectorize(dist, maybe_reconstruct_and_link(vi, vn, dist, r[:, i])),
480+
vn,
481+
)
474482
setorder!(vi, vn, get_num_produce(vi))
475483
end
476484
else
@@ -508,13 +516,17 @@ function get_and_set_val!(
508516
for i in eachindex(vns)
509517
vn = vns[i]
510518
dist = dists isa AbstractArray ? dists[i] : dists
511-
setindex!!(vi, vectorize(dist, maybe_link(vi, vn, dist, r[i])), vn)
519+
setindex!!(
520+
vi, vectorize(dist, maybe_reconstruct_and_link(vi, vn, dist, r[i])), vn
521+
)
512522
setorder!(vi, vn, get_num_produce(vi))
513523
end
514524
else
515525
# r = reshape(vi[vec(vns)], size(vns))
526+
# FIXME: Remove `reconstruct` in `getindex_raw(::VarInfo, ...)`
527+
# and fix the lines below.
516528
r_raw = getindex_raw(vi, vec(vns))
517-
r = maybe_invlink.((vi,), vns, dists, reshape(r_raw, size(vns)))
529+
r = maybe_invlink_and_reconstruct.((vi,), vns, dists, reshape(r_raw, size(vns)))
518530
end
519531
else
520532
f = (vn, dist) -> init(rng, dist, spl)
@@ -525,7 +537,7 @@ function get_and_set_val!(
525537
# 2. Define an anonymous function which returns `nothing`, which
526538
# we then broadcast. This will allocate a vector of `nothing` though.
527539
if istrans(vi)
528-
push!!.((vi,), vns, link.((vi,), vns, dists, r), dists, (spl,))
540+
push!!.((vi,), vns, reconstruct_and_link.((vi,), vns, dists, r), dists, (spl,))
529541
# NOTE: Need to add the correction.
530542
acclogp!!(vi, sum(logabsdetjac.(bijector.(dists), r)))
531543
# `push!!` sets the trans-flag to `false` by default.
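
The `logpdf(dist, r) - logjac` correction in `assume` uses the log-Jacobian of the link evaluated at the constrained value, whereas `invlink_with_logpdf` adds the log-Jacobian of the inverse link evaluated at the linked value. A minimal sketch, assuming a positive-support variable with `log` as the link, suggests the two agree. All `toy_` names are hypothetical.

```julia
# Hypothetical toy functions for a positive-support variable; not library API.
toy_link(x) = log(x)                  # constrained -> unconstrained
toy_invlink(y) = exp(y)               # unconstrained -> constrained
toy_logabsdetjac_link(x) = -log(x)    # log|d/dx log(x)| = -log(x)
toy_logabsdetjac_invlink(y) = y       # log|d/dy exp(y)| = y

x = 2.5
y = toy_link(x)

# Correction expressed via the constrained value (subtract the link's log-Jacobian).
correction_from_link = -toy_logabsdetjac_link(x)
# Correction expressed via the linked value (add the inverse link's log-Jacobian).
correction_from_invlink = toy_logabsdetjac_invlink(y)
```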

src/simple_varinfo.jl

+10-7
@@ -290,7 +290,7 @@ end
290290

291291
# `NamedTuple`
292292
function Base.getindex(vi::SimpleVarInfo, vn::VarName, dist::Distribution)
293-
return maybe_invlink(vi, vn, dist, getindex(vi, vn))
293+
return maybe_invlink_and_reconstruct(vi, vn, dist, getindex(vi, vn))
294294
end
295295
function Base.getindex(vi::SimpleVarInfo, vns::Vector{<:VarName}, dist::Distribution)
296296
vals_linked = mapreduce(vcat, vns) do vn
@@ -329,6 +329,9 @@ function getindex_raw(vi::SimpleVarInfo, vns::Vector{<:VarName}, dist::Distribut
329329
return reconstruct(dist, vals, length(vns))
330330
end
331331

332+
# HACK: because `VarInfo` isn't ready to implement a proper `getindex_raw`.
333+
getval(vi::SimpleVarInfo, vn::VarName) = getindex_raw(vi, vn)
334+
332335
Base.haskey(vi::SimpleVarInfo, vn::VarName) = hasvalue(vi.values, vn)
333336

334337
function BangBang.setindex!!(vi::SimpleVarInfo, val, vn::VarName)
@@ -426,7 +429,7 @@ function assume(
426429
)
427430
value = init(rng, dist, sampler)
428431
# Transform if we're working in unconstrained space.
429-
value_raw = maybe_link(vi, vn, dist, value)
432+
value_raw = maybe_reconstruct_and_link(vi, vn, dist, value)
430433
vi = BangBang.push!!(vi, vn, value_raw, dist, sampler)
431434
return value, Bijectors.logpdf_with_trans(dist, value, istrans(vi, vn)), vi
432435
end
@@ -444,9 +447,9 @@ function dot_assume(
444447

445448
# Transform if we're working in transformed space.
446449
value_raw = if dists isa Distribution
447-
maybe_link.((vi,), vns, (dists,), value)
450+
maybe_reconstruct_and_link.((vi,), vns, (dists,), value)
448451
else
449-
maybe_link.((vi,), vns, dists, value)
452+
maybe_reconstruct_and_link.((vi,), vns, dists, value)
450453
end
451454

452455
# Update `vi`
@@ -473,7 +476,7 @@ function dot_assume(
473476

474477
# Update `vi`.
475478
for (vn, val) in zip(vns, eachcol(value))
476-
val_linked = maybe_link(vi, vn, dist, val)
479+
val_linked = maybe_reconstruct_and_link(vi, vn, dist, val)
477480
vi = BangBang.setindex!!(vi, val_linked, vn)
478481
end
479482

@@ -488,7 +491,7 @@ function tonamedtuple(vi::SimpleOrThreadSafeSimple{<:NamedTuple{names}}) where {
488491
nt_vals = map(keys(vi)) do vn
489492
val = vi[vn]
490493
vns = collect(DynamicPPL.TestUtils.varname_leaves(vn, val))
491-
vals = map(Base.Fix1(getindex, vi), vns)
494+
vals = map(copy ∘ Base.Fix1(getindex, vi), vns)
492495
(vals, map(string, vns))
493496
end
494497

@@ -501,7 +504,7 @@ function tonamedtuple(vi::SimpleOrThreadSafeSimple{<:Dict})
501504
# Extract the leaf varnames and values.
502505
val = vi[vn]
503506
vns = collect(DynamicPPL.TestUtils.varname_leaves(vn, val))
504-
vals = map(Base.Fix1(getindex, vi), vns)
507+
vals = map(copy ∘ Base.Fix1(getindex, vi), vns)
505508

506509
# Determine the corresponding symbol.
507510
sym = only(unique(map(getsym, vns)))
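
The reason for composing `copy` with `Base.Fix1(getindex, vi)` in `tonamedtuple` can be shown with a plain `Dict` standing in for the varinfo. This is only a sketch; `store` and `ks` are hypothetical names, and the real containers may behave differently.

```julia
# A toy container standing in for the values held by a varinfo (hypothetical).
store = Dict(:a => [1.0, 2.0], :b => [3.0])
ks = [:a, :b]

# Without `copy`, the returned arrays alias the stored ones.
aliased = map(Base.Fix1(getindex, store), ks)
# With `copy ∘ ...`, each returned array is independent of the store.
copied = map(copy ∘ Base.Fix1(getindex, store), ks)

aliased[1][1] = 100.0   # also changes store[:a]
copied[2][1] = -1.0     # leaves store[:b] untouched
```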

src/threadsafe.jl

+2
@@ -178,3 +178,5 @@ end
178178

179179
istrans(vi::ThreadSafeVarInfo, vn::VarName) = istrans(vi.varinfo, vn)
180180
istrans(vi::ThreadSafeVarInfo, vns::AbstractVector{<:VarName}) = istrans(vi.varinfo, vns)
181+
182+
getval(vi::ThreadSafeVarInfo, vn::VarName) = getval(vi.varinfo, vn)

src/transforming.jl

+3-3
@@ -15,7 +15,7 @@ function tilde_assume(
1515

1616
# Only transform if `!isinverse` since `vi[vn, right]`
1717
# already performs the inverse transformation if it's transformed.
18-
r_transformed = isinverse ? r : bijector(right)(r)
18+
r_transformed = isinverse ? r : link_transform(right)(r)
1919
return r, lp, setindex!!(vi, r_transformed, vn)
2020
end
2121

@@ -27,7 +27,7 @@ function dot_tilde_assume(
2727
vi,
2828
) where {isinverse}
2929
r = getindex.((vi,), vns, (dist,))
30-
b = bijector(dist)
30+
b = link_transform(dist)
3131

3232
is_trans_uniques = unique(istrans.((vi,), vns))
3333
@assert length(is_trans_uniques) == 1 "DynamicTransformationContext only supports transforming all variables"
@@ -70,7 +70,7 @@ function dot_tilde_assume(
7070
@assert !isinverse "Trying to invlink non-transformed variables"
7171
end
7272

73-
b = bijector(dist)
73+
b = link_transform(dist)
7474
for (vn, ri) in zip(vns, eachcol(r))
7575
# Only transform if `!isinverse` since `vi[vn, right]`
7676
# already performs the inverse transformation if it's transformed.
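
Replacing `bijector(dist)` with `link_transform(dist)` introduces a package-level seam that defaults to the library's choice but can be overridden independently. The sketch below illustrates that dispatch pattern with hypothetical toy types and functions; none of the `Toy`/`toy_` names exist in DynamicPPL or Bijectors.

```julia
# Hypothetical toy types and functions, for illustration only.
struct ToyPositive end        # stands in for a distribution on (0, Inf)
struct ToyUnitInterval end    # stands in for a distribution on (0, 1)

# Library-level default transform (plays the role of `bijector`).
toy_bijector(::ToyPositive) = log
toy_bijector(::ToyUnitInterval) = x -> log(x / (1 - x))

# Package-level seam (plays the role of `link_transform`): defaults to the library choice.
toy_link_transform(dist) = toy_bijector(dist)

# The package can change its own choice without touching `toy_bijector`.
toy_link_transform(::ToyPositive) = x -> log(x) / 2

positive_linked = toy_link_transform(ToyPositive())(4.0)
unit_linked = toy_link_transform(ToyUnitInterval())(0.5)
library_linked = toy_bijector(ToyPositive())(4.0)
```

Code that still calls the library-level function directly keeps the old behavior, which is the situation the docstring warning about `Bijectors.logpdf_with_trans` describes.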

src/utils.jl

+46-1
@@ -177,10 +177,39 @@ function to_namedtuple_expr(syms, vals)
177177
return :(NamedTuple{$names_expr}($vals_expr))
178178
end
179179

180+
"""
181+
link_transform(dist)
182+
183+
Return the constrained-to-unconstrained bijector for distribution `dist`.
184+
185+
By default, this is just `Bijectors.bijector(dist)`.
186+
187+
!!! warning
188+
Note that currently this is not used by `Bijectors.logpdf_with_trans`,
189+
hence that needs to be overloaded separately if the intention is
190+
to change behavior of an existing distribution.
191+
"""
192+
link_transform(dist) = bijector(dist)
193+
194+
"""
195+
invlink_transform(dist)
196+
197+
Return the unconstrained-to-constrained bijector for distribution `dist`.
198+
199+
By default, this is just `inverse(link_transform(dist))`.
200+
201+
!!! warning
202+
Note that currently this is not used by `Bijectors.logpdf_with_trans`,
203+
hence that needs to be overloaded separately if the intention is
204+
to change behavior of an existing distribution.
205+
"""
206+
invlink_transform(dist) = inverse(link_transform(dist))
207+
180208
#####################################################
181209
# Helper functions for vectorize/reconstruct values #
182210
#####################################################
183211

212+
vectorize(d, r) = vec(r)
184213
vectorize(d::UnivariateDistribution, r::Real) = [r]
185214
vectorize(d::MultivariateDistribution, r::AbstractVector{<:Real}) = copy(r)
186215
vectorize(d::MatrixDistribution, r::AbstractMatrix{<:Real}) = copy(vec(r))
@@ -191,7 +220,23 @@ vectorize(d::MatrixDistribution, r::AbstractMatrix{<:Real}) = copy(vec(r))
191220
# otherwise we will have error for MatrixDistribution.
192221
# Note this is not the case for MultivariateDistribution so I guess this might be lack of
193222
# support for some types related to matrices (like PDMat).
194-
reconstruct(d::UnivariateDistribution, val::Real) = val
223+
224+
"""
225+
reconstruct([f, ]dist, val)
226+
227+
Reconstruct `val` so that it's compatible with `dist`.
228+
229+
If `f` is also provided, the reconstructed value will be
230+
such that `f(reconstruct_val)` is compatible with `dist`.
231+
"""
232+
reconstruct(f, dist, val) = reconstruct(dist, val)
233+
234+
# No-op versions.
235+
reconstruct(::UnivariateDistribution, val::Real) = val
236+
reconstruct(::MultivariateDistribution, val::AbstractVector{<:Real}) = copy(val)
237+
reconstruct(::MatrixDistribution, val::AbstractMatrix{<:Real}) = copy(val)
238+
# TODO: Implement no-op `reconstruct` for general array variates.
239+
195240
reconstruct(d::Distribution, val::AbstractVector) = reconstruct(size(d), val)
196241
reconstruct(::Tuple{}, val::AbstractVector) = val[1]
197242
reconstruct(s::NTuple{1}, val::AbstractVector) = copy(val)
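
The `vectorize`/`reconstruct` pair stores every value as a flat vector and restores its shape from the distribution's size. The round trip can be sketched with hypothetical `toy_` helpers that take the size tuple directly; the two-dimensional method is an assumption, since the corresponding method is not shown in this hunk.

```julia
# Hypothetical helpers mirroring the size-based `reconstruct` methods above.
toy_vectorize(r::Real) = [r]
toy_vectorize(r::AbstractArray) = copy(vec(r))

toy_reconstruct(::Tuple{}, val::AbstractVector) = val[1]            # scalar variate
toy_reconstruct(::NTuple{1,Int}, val::AbstractVector) = copy(val)   # vector variate
# Assumption: matrix variates are restored by reshaping to the stored size.
toy_reconstruct(s::NTuple{2,Int}, val::AbstractVector) = reshape(copy(val), s)

m = [1.0 2.0; 3.0 4.0]
flat = toy_vectorize(m)                    # column-major flattening
restored = toy_reconstruct(size(m), flat)
scalar = toy_reconstruct((), toy_vectorize(0.3))
```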
