diff --git a/docs/src/api.md b/docs/src/api.md index 0e4012e02..2dfda9119 100644 --- a/docs/src/api.md +++ b/docs/src/api.md @@ -206,8 +206,7 @@ DynamicPPL.link!! DynamicPPL.invlink!! DynamicPPL.default_transformation DynamicPPL.maybe_invlink_before_eval!! -DynamicPPL.reconstruct -``` +``` #### Utils diff --git a/src/DynamicPPL.jl b/src/DynamicPPL.jl index 8904cfe81..594084d66 100644 --- a/src/DynamicPPL.jl +++ b/src/DynamicPPL.jl @@ -43,6 +43,7 @@ export AbstractVarInfo, push!!, empty!!, getlogp, + resetlogp!, setlogp!!, acclogp!!, resetlogp!!, diff --git a/src/abstract_varinfo.jl b/src/abstract_varinfo.jl index 116890d7b..acd51e288 100644 --- a/src/abstract_varinfo.jl +++ b/src/abstract_varinfo.jl @@ -553,99 +553,6 @@ variables `x` would return """ function tonamedtuple end -# TODO: Clean up all this linking stuff once and for all! -""" - with_logabsdet_jacobian_and_reconstruct([f, ]dist, x) - -Like `Bijectors.with_logabsdet_jacobian(f, x)`, but also ensures the resulting -value is reconstructed to the correct type and shape according to `dist`. -""" -function with_logabsdet_jacobian_and_reconstruct(f, dist, x) - x_recon = reconstruct(f, dist, x) - return with_logabsdet_jacobian(f, x_recon) -end - -# TODO: Once `(inv)link` isn't used heavily in `getindex(vi, vn)`, we can -# just use `first ∘ with_logabsdet_jacobian` to reduce the maintenance burden. -# NOTE: `reconstruct` is no-op if `val` is already of correct shape. -""" - reconstruct_and_link(dist, val) - reconstruct_and_link(vi::AbstractVarInfo, vi::VarName, dist, val) - -Return linked `val` but reconstruct before linking, if necessary. - -Note that unlike [`invlink_and_reconstruct`](@ref), this does not necessarily -return a reconstructed value, i.e. a value of the same type and shape as expected -by `dist`. - -See also: [`invlink_and_reconstruct`](@ref), [`reconstruct`](@ref). -""" -reconstruct_and_link(f, dist, val) = f(reconstruct(f, dist, val)) -reconstruct_and_link(dist, val) = reconstruct_and_link(link_transform(dist), dist, val) -function reconstruct_and_link(::AbstractVarInfo, ::VarName, dist, val) - return reconstruct_and_link(dist, val) -end - -""" - invlink_and_reconstruct(dist, val) - invlink_and_reconstruct(vi::AbstractVarInfo, vn::VarName, dist, val) - -Return invlinked and reconstructed `val`. - -See also: [`reconstruct_and_link`](@ref), [`reconstruct`](@ref). -""" -invlink_and_reconstruct(f, dist, val) = f(reconstruct(f, dist, val)) -function invlink_and_reconstruct(dist, val) - return invlink_and_reconstruct(invlink_transform(dist), dist, val) -end -function invlink_and_reconstruct(::AbstractVarInfo, ::VarName, dist, val) - return invlink_and_reconstruct(dist, val) -end - -""" - maybe_link_and_reconstruct(vi::AbstractVarInfo, vn::VarName, dist, val) - -Return reconstructed `val`, possibly linked if `istrans(vi, vn)` is `true`. -""" -function maybe_reconstruct_and_link(vi::AbstractVarInfo, vn::VarName, dist, val) - return if istrans(vi, vn) - reconstruct_and_link(vi, vn, dist, val) - else - reconstruct(dist, val) - end -end - -""" - maybe_invlink_and_reconstruct(vi::AbstractVarInfo, vn::VarName, dist, val) - -Return reconstructed `val`, possibly invlinked if `istrans(vi, vn)` is `true`. -""" -function maybe_invlink_and_reconstruct(vi::AbstractVarInfo, vn::VarName, dist, val) - return if istrans(vi, vn) - invlink_and_reconstruct(vi, vn, dist, val) - else - reconstruct(dist, val) - end -end - -""" - invlink_with_logpdf(vi::AbstractVarInfo, vn::VarName, dist[, x]) - -Invlink `x` and compute the logpdf under `dist` including correction from -the invlink-transformation. - -If `x` is not provided, `getval(vi, vn)` will be used. -""" -function invlink_with_logpdf(vi::AbstractVarInfo, vn::VarName, dist) - return invlink_with_logpdf(vi, vn, dist, getval(vi, vn)) -end -function invlink_with_logpdf(vi::AbstractVarInfo, vn::VarName, dist, y) - # NOTE: Will this cause type-instabilities or will union-splitting save us? - f = istrans(vi, vn) ? invlink_transform(dist) : identity - x, logjac = with_logabsdet_jacobian_and_reconstruct(f, dist, y) - return x, logpdf(dist, x) + logjac -end - # Legacy code that is currently overloaded for the sake of simplicity. # TODO: Remove when possible. increment_num_produce!(::AbstractVarInfo) = nothing diff --git a/src/context_implementations.jl b/src/context_implementations.jl index 1f0641007..1078a0e18 100644 --- a/src/context_implementations.jl +++ b/src/context_implementations.jl @@ -194,8 +194,8 @@ end # fallback without sampler function assume(dist::Distribution, vn::VarName, vi) - r, logp = invlink_with_logpdf(vi, vn, dist) - return r, logp, vi + r = vi[vn, dist] + return r, Bijectors.logpdf_with_trans(dist, r, istrans(vi, vn)), vi end # SampleFromPrior and SampleFromUniform @@ -211,9 +211,7 @@ function assume( if sampler isa SampleFromUniform || is_flagged(vi, vn, "del") unset_flag!(vi, vn, "del") r = init(rng, dist, sampler) - BangBang.setindex!!( - vi, vectorize(dist, maybe_reconstruct_and_link(vi, vn, dist, r)), vn - ) + BangBang.setindex!!(vi, vectorize(dist, maybe_link(vi, vn, dist, r)), vn) setorder!(vi, vn, get_num_produce(vi)) else # Otherwise we just extract it. @@ -222,7 +220,7 @@ function assume( else r = init(rng, dist, sampler) if istrans(vi) - push!!(vi, vn, reconstruct_and_link(dist, r), dist, sampler) + push!!(vi, vn, link(dist, r), dist, sampler) # By default `push!!` sets the transformed flag to `false`. settrans!!(vi, true, vn) else @@ -230,9 +228,7 @@ function assume( end end - # HACK: The above code might involve an `invlink` somewhere, etc. so we need to correct. - logjac = logabsdetjac(istrans(vi, vn) ? link_transform(dist) : identity, r) - return r, logpdf(dist, r) - logjac, vi + return r, Bijectors.logpdf_with_trans(dist, r, istrans(vi, vn)), vi end # default fallback (used e.g. by `SampleFromPrior` and `SampleUniform`) @@ -474,11 +470,7 @@ function get_and_set_val!( r = init(rng, dist, spl, n) for i in 1:n vn = vns[i] - setindex!!( - vi, - vectorize(dist, maybe_reconstruct_and_link(vi, vn, dist, r[:, i])), - vn, - ) + setindex!!(vi, vectorize(dist, maybe_link(vi, vn, dist, r[:, i])), vn) setorder!(vi, vn, get_num_produce(vi)) end else @@ -516,17 +508,13 @@ function get_and_set_val!( for i in eachindex(vns) vn = vns[i] dist = dists isa AbstractArray ? dists[i] : dists - setindex!!( - vi, vectorize(dist, maybe_reconstruct_and_link(vi, vn, dist, r[i])), vn - ) + setindex!!(vi, vectorize(dist, maybe_link(vi, vn, dist, r[i])), vn) setorder!(vi, vn, get_num_produce(vi)) end else # r = reshape(vi[vec(vns)], size(vns)) - # FIXME: Remove `reconstruct` in `getindex_raw(::VarInfo, ...)` - # and fix the lines below. r_raw = getindex_raw(vi, vec(vns)) - r = maybe_invlink_and_reconstruct.((vi,), vns, dists, reshape(r_raw, size(vns))) + r = maybe_invlink.((vi,), vns, dists, reshape(r_raw, size(vns))) end else f = (vn, dist) -> init(rng, dist, spl) @@ -537,7 +525,7 @@ function get_and_set_val!( # 2. Define an anonymous function which returns `nothing`, which # we then broadcast. This will allocate a vector of `nothing` though. if istrans(vi) - push!!.((vi,), vns, reconstruct_and_link.((vi,), vns, dists, r), dists, (spl,)) + push!!.((vi,), vns, link.((vi,), vns, dists, r), dists, (spl,)) # NOTE: Need to add the correction. acclogp!!(vi, sum(logabsdetjac.(bijector.(dists), r))) # `push!!` sets the trans-flag to `false` by default. diff --git a/src/simple_varinfo.jl b/src/simple_varinfo.jl index 68b3d0ae2..a445bf87a 100644 --- a/src/simple_varinfo.jl +++ b/src/simple_varinfo.jl @@ -290,7 +290,7 @@ end # `NamedTuple` function Base.getindex(vi::SimpleVarInfo, vn::VarName, dist::Distribution) - return maybe_invlink_and_reconstruct(vi, vn, dist, getindex(vi, vn)) + return maybe_invlink(vi, vn, dist, getindex(vi, vn)) end function Base.getindex(vi::SimpleVarInfo, vns::Vector{<:VarName}, dist::Distribution) vals_linked = mapreduce(vcat, vns) do vn @@ -329,9 +329,6 @@ function getindex_raw(vi::SimpleVarInfo, vns::Vector{<:VarName}, dist::Distribut return reconstruct(dist, vals, length(vns)) end -# HACK: because `VarInfo` isn't ready to implement a proper `getindex_raw`. -getval(vi::SimpleVarInfo, vn::VarName) = getindex_raw(vi, vn) - Base.haskey(vi::SimpleVarInfo, vn::VarName) = hasvalue(vi.values, vn) function BangBang.setindex!!(vi::SimpleVarInfo, val, vn::VarName) @@ -429,7 +426,7 @@ function assume( ) value = init(rng, dist, sampler) # Transform if we're working in unconstrained space. - value_raw = maybe_reconstruct_and_link(vi, vn, dist, value) + value_raw = maybe_link(vi, vn, dist, value) vi = BangBang.push!!(vi, vn, value_raw, dist, sampler) return value, Bijectors.logpdf_with_trans(dist, value, istrans(vi, vn)), vi end @@ -447,9 +444,9 @@ function dot_assume( # Transform if we're working in transformed space. value_raw = if dists isa Distribution - maybe_reconstruct_and_link.((vi,), vns, (dists,), value) + maybe_link.((vi,), vns, (dists,), value) else - maybe_reconstruct_and_link.((vi,), vns, dists, value) + maybe_link.((vi,), vns, dists, value) end # Update `vi` @@ -476,7 +473,7 @@ function dot_assume( # Update `vi`. for (vn, val) in zip(vns, eachcol(value)) - val_linked = maybe_reconstruct_and_link(vi, vn, dist, val) + val_linked = maybe_link(vi, vn, dist, val) vi = BangBang.setindex!!(vi, val_linked, vn) end @@ -491,7 +488,7 @@ function tonamedtuple(vi::SimpleOrThreadSafeSimple{<:NamedTuple{names}}) where { nt_vals = map(keys(vi)) do vn val = vi[vn] vns = collect(DynamicPPL.TestUtils.varname_leaves(vn, val)) - vals = map(copy ∘ Base.Fix1(getindex, vi), vns) + vals = map(Base.Fix1(getindex, vi), vns) (vals, map(string, vns)) end @@ -504,7 +501,7 @@ function tonamedtuple(vi::SimpleOrThreadSafeSimple{<:Dict}) # Extract the leaf varnames and values. val = vi[vn] vns = collect(DynamicPPL.TestUtils.varname_leaves(vn, val)) - vals = map(copy ∘ Base.Fix1(getindex, vi), vns) + vals = map(Base.Fix1(getindex, vi), vns) # Determine the corresponding symbol. sym = only(unique(map(getsym, vns))) diff --git a/src/threadsafe.jl b/src/threadsafe.jl index dc8720e0a..85ad0e23e 100644 --- a/src/threadsafe.jl +++ b/src/threadsafe.jl @@ -178,5 +178,3 @@ end istrans(vi::ThreadSafeVarInfo, vn::VarName) = istrans(vi.varinfo, vn) istrans(vi::ThreadSafeVarInfo, vns::AbstractVector{<:VarName}) = istrans(vi.varinfo, vns) - -getval(vi::ThreadSafeVarInfo, vn::VarName) = getval(vi.varinfo, vn) diff --git a/src/transforming.jl b/src/transforming.jl index bb8abddd6..f4b50b057 100644 --- a/src/transforming.jl +++ b/src/transforming.jl @@ -15,7 +15,7 @@ function tilde_assume( # Only transform if `!isinverse` since `vi[vn, right]` # already performs the inverse transformation if it's transformed. - r_transformed = isinverse ? r : link_transform(right)(r) + r_transformed = isinverse ? r : bijector(right)(r) return r, lp, setindex!!(vi, r_transformed, vn) end @@ -27,7 +27,7 @@ function dot_tilde_assume( vi, ) where {isinverse} r = getindex.((vi,), vns, (dist,)) - b = link_transform(dist) + b = bijector(dist) is_trans_uniques = unique(istrans.((vi,), vns)) @assert length(is_trans_uniques) == 1 "DynamicTransformationContext only supports transforming all variables" @@ -70,7 +70,7 @@ function dot_tilde_assume( @assert !isinverse "Trying to invlink non-transformed variables" end - b = link_transform(dist) + b = bijector(dist) for (vn, ri) in zip(vns, eachcol(r)) # Only transform if `!isinverse` since `vi[vn, right]` # already performs the inverse transformation if it's transformed. diff --git a/src/utils.jl b/src/utils.jl index 8f3a0d101..f0bba2071 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -177,39 +177,10 @@ function to_namedtuple_expr(syms, vals) return :(NamedTuple{$names_expr}($vals_expr)) end -""" - link_transform(dist) - -Return the constrained-to-unconstrained bijector for distribution `dist`. - -By default, this is just `Bijectors.bijector(dist)`. - -!!! warning - Note that currently this is not used by `Bijectors.logpdf_with_trans`, - hence that needs to be overloaded separately if the intention is - to change behavior of an existing distribution. -""" -link_transform(dist) = bijector(dist) - -""" - invlink_transform(dist) - -Return the unconstrained-to-constrained bijector for distribution `dist`. - -By default, this is just `inverse(link_transform(dist))`. - -!!! warning - Note that currently this is not used by `Bijectors.logpdf_with_trans`, - hence that needs to be overloaded separately if the intention is - to change behavior of an existing distribution. -""" -invlink_transform(dist) = inverse(link_transform(dist)) - ##################################################### # Helper functions for vectorize/reconstruct values # ##################################################### -vectorize(d, r) = vec(r) vectorize(d::UnivariateDistribution, r::Real) = [r] vectorize(d::MultivariateDistribution, r::AbstractVector{<:Real}) = copy(r) vectorize(d::MatrixDistribution, r::AbstractMatrix{<:Real}) = copy(vec(r)) @@ -220,23 +191,7 @@ vectorize(d::MatrixDistribution, r::AbstractMatrix{<:Real}) = copy(vec(r)) # otherwise we will have error for MatrixDistribution. # Note this is not the case for MultivariateDistribution so I guess this might be lack of # support for some types related to matrices (like PDMat). - -""" - reconstruct([f, ]dist, val) - -Reconstruct `val` so that it's compatible with `dist`. - -If `f` is also provided, the reconstruct value will be -such that `f(reconstruct_val)` is compatible with `dist`. -""" -reconstruct(f, dist, val) = reconstruct(dist, val) - -# No-op versions. -reconstruct(::UnivariateDistribution, val::Real) = val -reconstruct(::MultivariateDistribution, val::AbstractVector{<:Real}) = copy(val) -reconstruct(::MatrixDistribution, val::AbstractMatrix{<:Real}) = copy(val) -# TODO: Implement no-op `reconstruct` for general array variates. - +reconstruct(d::UnivariateDistribution, val::Real) = val reconstruct(d::Distribution, val::AbstractVector) = reconstruct(size(d), val) reconstruct(::Tuple{}, val::AbstractVector) = val[1] reconstruct(s::NTuple{1}, val::AbstractVector) = copy(val) diff --git a/src/varinfo.jl b/src/varinfo.jl index a30c9ea24..17df5a97e 100644 --- a/src/varinfo.jl +++ b/src/varinfo.jl @@ -155,7 +155,7 @@ end for f in names mdf = :(metadata.$f) if inspace(f, space) || length(space) == 0 - len = :(sum(length, $mdf.ranges)) + len = :(length($mdf.vals)) push!( exprs, :( @@ -271,24 +271,14 @@ getmetadata(vi::TypedVarInfo, vn::VarName) = getfield(vi.metadata, getsym(vn)) Return the index of `vn` in the metadata of `vi` corresponding to `vn`. """ -getidx(vi::VarInfo, vn::VarName) = getidx(getmetadata(vi, vn), vn) -getidx(md::Metadata, vn::VarName) = md.idcs[vn] +getidx(vi::VarInfo, vn::VarName) = getmetadata(vi, vn).idcs[vn] """ getrange(vi::VarInfo, vn::VarName) Return the index range of `vn` in the metadata of `vi`. """ -getrange(vi::VarInfo, vn::VarName) = getrange(getmetadata(vi, vn), vn) -getrange(md::Metadata, vn::VarName) = md.ranges[getidx(md, vn)] - -""" - setrange!(vi::VarInfo, vn::VarName, range) - -Set the index range of `vn` in the metadata of `vi` to `range`. -""" -setrange!(vi::VarInfo, vn::VarName, range) = setrange!(getmetadata(vi, vn), vn, range) -setrange!(md::Metadata, vn::VarName, range) = md.ranges[getidx(md, vn)] = range +getrange(vi::VarInfo, vn::VarName) = getmetadata(vi, vn).ranges[getidx(vi, vn)] """ getranges(vi::VarInfo, vns::Vector{<:VarName}) @@ -304,8 +294,7 @@ end Return the distribution from which `vn` was sampled in `vi`. """ -getdist(vi::VarInfo, vn::VarName) = getdist(getmetadata(vi, vn), vn) -getdist(md::Metadata, vn::VarName) = md.dists[getidx(md, vn)] +getdist(vi::VarInfo, vn::VarName) = getmetadata(vi, vn).dists[getidx(vi, vn)] """ getval(vi::VarInfo, vn::VarName) @@ -314,8 +303,7 @@ Return the value(s) of `vn`. The values may or may not be transformed to Euclidean space. """ -getval(vi::VarInfo, vn::VarName) = getval(getmetadata(vi, vn), vn) -getval(md::Metadata, vn::VarName) = view(md.vals, getrange(md, vn)) +getval(vi::VarInfo, vn::VarName) = view(getmetadata(vi, vn).vals, getrange(vi, vn)) """ setval!(vi::VarInfo, val, vn::VarName) @@ -324,8 +312,7 @@ Set the value(s) of `vn` in the metadata of `vi` to `val`. The values may or may not be transformed to Euclidean space. """ -setval!(vi::VarInfo, val, vn::VarName) = setval!(getmetadata(vi, vn), val, vn) -setval!(md::Metadata, val, vn::VarName) = md.vals[getrange(md, vn)] = [val;] +setval!(vi::VarInfo, val, vn::VarName) = getmetadata(vi, vn).vals[getrange(vi, vn)] = [val;] """ getval(vi::VarInfo, vns::Vector{<:VarName}) @@ -334,7 +321,9 @@ Return the value(s) of `vns`. The values may or may not be transformed to Euclidean space. """ -getval(vi::VarInfo, vns::Vector{<:VarName}) = mapreduce(Base.Fix1(getval, vi), vcat, vns) +function getval(vi::VarInfo, vns::Vector{<:VarName}) + return mapreduce(vn -> getval(vi, vn), vcat, vns) +end """ getall(vi::VarInfo) @@ -343,12 +332,14 @@ Return the values of all the variables in `vi`. The values may or may not be transformed to Euclidean space. """ -getall(vi::UntypedVarInfo) = getall(vi.metadata) -# NOTE: `mapreduce` over `NamedTuple` results in worse type-inference. -# See for example https://github.com/JuliaLang/julia/pull/46381. -getall(vi::TypedVarInfo) = reduce(vcat, map(getall, vi.metadata)) -function getall(md::Metadata) - return mapreduce(Base.Fix1(getval, md), vcat, md.vns; init=similar(md.vals, 0)) +getall(vi::UntypedVarInfo) = vi.metadata.vals +getall(vi::TypedVarInfo) = vcat(_getall(vi.metadata)...) +@generated function _getall(metadata::NamedTuple{names}) where {names} + exprs = [] + for f in names + push!(exprs, :(metadata.$f.vals)) + end + return :($(exprs...),) end """ @@ -748,13 +739,19 @@ function link!(vi::VarInfo, spl::AbstractSampler, spaceval::Val) ) return _link!(vi, spl, spaceval) end -function _link!(vi::UntypedVarInfo, spl::AbstractSampler) +function _link!(vi::UntypedVarInfo, spl::Sampler) # TODO: Change to a lazy iterator over `vns` vns = _getvns(vi, spl) if ~istrans(vi, vns[1]) for vn in vns + @debug "X -> ℝ for $(vn)..." dist = getdist(vi, vn) - _inner_transform!(vi, vn, dist, link_transform(dist)) + # TODO: Use inplace versions to avoid allocations + b = bijector(dist) + x = reconstruct(dist, getval(vi, vn)) + y, logjac = with_logabsdet_jacobian(b, x) + setval!(vi, vectorize(dist, y), vn) + acclogp!!(vi, -logjac) settrans!!(vi, true, vn) end else @@ -781,8 +778,13 @@ end if ~istrans(vi, f_vns[1]) # Iterate over all `f_vns` and transform for vn in f_vns + @debug "X -> R for $(vn)..." dist = getdist(vi, vn) - _inner_transform!(vi, vn, dist, link_transform(dist)) + x = reconstruct(dist, getval(vi, vn)) + b = bijector(dist) + y, logjac = with_logabsdet_jacobian(b, x) + setval!(vi, vectorize(dist, y), vn) + acclogp!!(vi, -logjac) settrans!!(vi, true, vn) end else @@ -837,8 +839,13 @@ function _invlink!(vi::UntypedVarInfo, spl::AbstractSampler) vns = _getvns(vi, spl) if istrans(vi, vns[1]) for vn in vns + @debug "ℝ -> X for $(vn)..." dist = getdist(vi, vn) - _inner_transform!(vi, vn, dist, invlink_transform(dist)) + y = reconstruct(dist, getval(vi, vn)) + b = inverse(bijector(dist)) + x, logjac = with_logabsdet_jacobian(b, y) + setval!(vi, vectorize(dist, x), vn) + acclogp!!(vi, -logjac) settrans!!(vi, false, vn) end else @@ -865,8 +872,13 @@ end if istrans(vi, f_vns[1]) # Iterate over all `f_vns` and transform for vn in f_vns + @debug "ℝ -> X for $(vn)..." dist = getdist(vi, vn) - _inner_transform!(vi, vn, dist, invlink_transform(dist)) + y = reconstruct(dist, getval(vi, vn)) + b = inverse(bijector(dist)) + x, logjac = with_logabsdet_jacobian(b, y) + setval!(vi, vectorize(dist, x), vn) + acclogp!!(vi, -logjac) settrans!!(vi, false, vn) end else @@ -879,20 +891,10 @@ end return expr end -function _inner_transform!(vi::VarInfo, vn::VarName, dist, f) - @debug "X -> ℝ for $(vn)..." - # TODO: Use inplace versions to avoid allocations - y, logjac = with_logabsdet_jacobian_and_reconstruct(f, dist, getval(vi, vn)) - yvec = vectorize(dist, y) - # Determine the new range. - start = first(getrange(vi, vn)) - # NOTE: `length(yvec)` should never be longer than `getrange(vi, vn)`. - setrange!(vi, vn, start:(start + length(yvec) - 1)) - # Set the new value. - setval!(vi, yvec, vn) - acclogp!!(vi, -logjac) - return vi -end +link(vi, vn, dist, val) = Bijectors.link(dist, val) +invlink(vi, vn, dist, val) = Bijectors.invlink(dist, val) +maybe_link(vi, vn, dist, val) = istrans(vi, vn) ? link(vi, vn, dist, val) : val +maybe_invlink(vi, vn, dist, val) = istrans(vi, vn) ? invlink(vi, vn, dist, val) : val """ islinked(vi::VarInfo, spl::Union{Sampler, SampleFromPrior}) @@ -925,8 +927,8 @@ end getindex(vi::VarInfo, vn::VarName) = getindex(vi, vn, getdist(vi, vn)) function getindex(vi::VarInfo, vn::VarName, dist::Distribution) @assert haskey(vi, vn) "[DynamicPPL] attempted to replay unexisting variables in VarInfo" - val = getval(vi, vn) - return maybe_invlink_and_reconstruct(vi, vn, dist, val) + val = getindex_raw(vi, vn, dist) + return maybe_invlink(vi, vn, dist, val) end function getindex(vi::VarInfo, vns::Vector{<:VarName}) # FIXME(torfjelde): Using `getdist(vi, first(vns))` won't be correct in cases @@ -1027,20 +1029,19 @@ end return expr end -# TODO: Remove this completely. -tonamedtuple(varinfo::VarInfo) = tonamedtuple(varinfo.metadata, varinfo) -function tonamedtuple(metadata::NamedTuple{names}, varinfo::VarInfo) where {names} - length(names) === 0 && return NamedTuple() - - vals_tuple = map(values(metadata)) do x - # NOTE: `tonamedtuple` is really only used in Turing.jl to convert to - # a "transition". This means that we really don't mutations of the values - # in `varinfo` to propoagate the previous samples. Hence we `copy.` - vals = map(copy ∘ Base.Fix1(getindex, varinfo), x.vns) - return vals, map(string, x.vns) +function tonamedtuple(vi::VarInfo) + return tonamedtuple(vi.metadata, vi) +end +@generated function tonamedtuple(metadata::NamedTuple{names}, vi::VarInfo) where {names} + length(names) === 0 && return :(NamedTuple()) + expr = Expr(:tuple) + map(names) do f + push!( + expr.args, + Expr(:(=), f, :(getindex.(Ref(vi), metadata.$f.vns), string.(metadata.$f.vns))), + ) end - - return NamedTuple{names}(vals_tuple) + return expr end @inline function findvns(vi, f_vns) diff --git a/test/linking.jl b/test/linking.jl deleted file mode 100644 index f81895788..000000000 --- a/test/linking.jl +++ /dev/null @@ -1,86 +0,0 @@ -using Bijectors - -# Simple transformations which alters the "dimension" of the variable. -struct TrilToVec{S} - size::S -end - -struct TrilFromVec{S} - size::S -end - -Bijectors.inverse(f::TrilToVec) = TrilFromVec(f.size) -Bijectors.inverse(f::TrilFromVec) = TrilToVec(f.size) - -function (v::TrilToVec)(x) - mask = tril(trues(v.size)) - return vec(x[mask]) -end -function (v::TrilFromVec)(y) - mask = tril(trues(v.size)) - x = similar(y, v.size) - x[mask] .= y - return LowerTriangular(x) -end - -# Just some dummy values so we can make sure that the log-prob computation -# has been altered correctly. -Bijectors.with_logabsdet_jacobian(f::TrilToVec, x) = (f(x), log(eltype(x)(2))) -Bijectors.with_logabsdet_jacobian(f::TrilFromVec, x) = (f(x), -eltype(x)(log(2))) - -# Dummy example. -struct MyMatrixDistribution <: ContinuousMatrixDistribution - dim::Int -end - -Base.size(d::MyMatrixDistribution) = (d.dim, d.dim) -function Distributions._rand!( - rng::AbstractRNG, d::MyMatrixDistribution, x::AbstractMatrix{<:Real} -) - return randn!(rng, x) -end -function Distributions._logpdf(::MyMatrixDistribution, x::AbstractMatrix{<:Real}) - return -sum(abs2, LowerTriangular(x)) / 2 -end - -# Skip reconstruction in the inverse-map since it's no longer needed. -DynamicPPL.reconstruct(::TrilFromVec, ::MyMatrixDistribution, x::AbstractVector{<:Real}) = x - -# Specify the link-transform to use. -Bijectors.bijector(dist::MyMatrixDistribution) = TrilToVec((dist.dim, dist.dim)) -function Bijectors.logpdf_with_trans(dist::MyMatrixDistribution, x, istrans::Bool) - lp = logpdf(dist, x) - if istrans - lp = lp - logabsdetjac(bijector(dist), x) - end - - return lp -end - -@testset "Linking" begin - # Just making sure the transformations are okay. - x = randn(3, 3) - f = TrilToVec((3, 3)) - f_inv = inverse(f) - y = f(x) - @test y isa AbstractVector - @test f_inv(f(x)) == LowerTriangular(x) - - # Within a model. - dist = MyMatrixDistribution(3) - @model demo() = m ~ dist - model = demo() - - vis = DynamicPPL.TestUtils.setup_varinfos(model, rand(model), (@varname(m),)) - @testset "$(short_varinfo_name(vi))" for vi in vis - # Evaluate once to ensure we have `logp` value. - vi = last(DynamicPPL.evaluate!!(model, vi, DefaultContext())) - vi_linked = DynamicPPL.link!!(deepcopy(vi), model) - # Difference should just be the log-absdet-jacobian "correction". - @test DynamicPPL.getlogp(vi) - DynamicPPL.getlogp(vi_linked) ≈ log(2) - @test vi_linked[@varname(m), dist] == LowerTriangular(vi[@varname(m), dist]) - # Linked one should be working with a lower-dimensional representation. - @test length(vi_linked[:]) < length(vi[:]) - @test length(vi_linked[:]) == 3 - end -end diff --git a/test/model.jl b/test/model.jl index d78133f5e..7fb8bcf0b 100644 --- a/test/model.jl +++ b/test/model.jl @@ -116,7 +116,7 @@ end model = DynamicPPL.TestUtils.demo_dynamic_constraint() vi = VarInfo(model) spl = SampleFromPrior() - link!!(vi, spl, model) + link!(vi, spl) for i in 1:10 # Sample with large variations. diff --git a/test/simple_varinfo.jl b/test/simple_varinfo.jl index 9a18c439d..a5b57f5f6 100644 --- a/test/simple_varinfo.jl +++ b/test/simple_varinfo.jl @@ -64,6 +64,7 @@ @testset "$(typeof(vi))" for vi in ( SimpleVarInfo(Dict()), SimpleVarInfo(values_constrained), VarInfo(model) ) + vi = SimpleVarInfo(values_constrained) for vn in DynamicPPL.TestUtils.varnames(model) vi = DynamicPPL.setindex!!(vi, get(values_constrained, vn), vn) end diff --git a/test/turing/model.jl b/test/turing/model.jl index fcbdd88a3..e27b177eb 100644 --- a/test/turing/model.jl +++ b/test/turing/model.jl @@ -1,12 +1,95 @@ @testset "model.jl" begin @testset "setval! & generated_quantities" begin - @testset "$model" for model in DynamicPPL.TestUtils.DEMO_MODELS - chain = sample(model, Prior(), 10) - # A simple way of checking that the computation is determinstic: run twice and compare. - res1 = generated_quantities(model, MCMCChains.get_sections(chain, :parameters)) - res2 = generated_quantities(model, MCMCChains.get_sections(chain, :parameters)) - @test all(res1 .== res2) - test_setval!(model, MCMCChains.get_sections(chain, :parameters)) + @model function demo1(xs, ::Type{TV}=Vector{Float64}) where {TV} + m = TV(undef, 2) + for i in 1:2 + m[i] ~ Normal(0, 1) + end + + for i in eachindex(xs) + xs[i] ~ Normal(m[1], 1.0) + end + + return (m,) end + + @model function demo2(xs) + m ~ MvNormal(zeros(2), I) + + for i in eachindex(xs) + xs[i] ~ Normal(m[1], 1.0) + end + + return (m,) + end + + xs = randn(3) + model1 = demo1(xs) + model2 = demo2(xs) + + chain1 = sample(model1, MH(), 100) + chain2 = sample(model2, MH(), 100) + + res11 = generated_quantities(model1, MCMCChains.get_sections(chain1, :parameters)) + res21 = generated_quantities(model2, MCMCChains.get_sections(chain1, :parameters)) + + res12 = generated_quantities(model1, MCMCChains.get_sections(chain2, :parameters)) + res22 = generated_quantities(model2, MCMCChains.get_sections(chain2, :parameters)) + + # Check that the two different models produce the same values for + # the same chains. + @test all(res11 .== res21) + @test all(res12 .== res22) + # Ensure that they're not all the same (some can be, because rejected samples) + @test any(res12[1:(end - 1)] .!= res12[2:end]) + + test_setval!(model1, MCMCChains.get_sections(chain1, :parameters)) + test_setval!(model2, MCMCChains.get_sections(chain2, :parameters)) + + # Next level + @model function demo3(xs, ::Type{TV}=Vector{Float64}) where {TV} + m = Vector{TV}(undef, 2) + for i in 1:length(m) + m[i] ~ MvNormal(zeros(2), I) + end + + for i in eachindex(xs) + xs[i] ~ Normal(m[1][1], 1.0) + end + + return (m,) + end + + @model function demo4(xs, ::Type{TV}=Vector{Vector{Float64}}) where {TV} + m = TV(undef, 2) + for i in 1:length(m) + m[i] ~ MvNormal(zeros(2), I) + end + + for i in eachindex(xs) + xs[i] ~ Normal(m[1][1], 1.0) + end + + return (m,) + end + + model3 = demo3(xs) + model4 = demo4(xs) + + chain3 = sample(model3, MH(), 100) + chain4 = sample(model4, MH(), 100) + + res33 = generated_quantities(model3, MCMCChains.get_sections(chain3, :parameters)) + res43 = generated_quantities(model4, MCMCChains.get_sections(chain3, :parameters)) + + res34 = generated_quantities(model3, MCMCChains.get_sections(chain4, :parameters)) + res44 = generated_quantities(model4, MCMCChains.get_sections(chain4, :parameters)) + + # Check that the two different models produce the same values for + # the same chains. + @test all(res33 .== res43) + @test all(res34 .== res44) + # Ensure that they're not all the same (some can be, because rejected samples) + @test any(res34[1:(end - 1)] .!= res34[2:end]) end end