Skip to content

Revert "Proper support for distributions with embedded support" #486

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
Closed
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 1 addition & 2 deletions docs/src/api.md
Original file line number Diff line number Diff line change
Expand Up @@ -206,8 +206,7 @@ DynamicPPL.link!!
DynamicPPL.invlink!!
DynamicPPL.default_transformation
DynamicPPL.maybe_invlink_before_eval!!
DynamicPPL.reconstruct
```
```
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[JuliaFormatter] reported by reviewdog 🐶

Suggested change
```
```


#### Utils

Expand Down
1 change: 1 addition & 0 deletions src/DynamicPPL.jl
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ export AbstractVarInfo,
push!!,
empty!!,
getlogp,
resetlogp!,
setlogp!!,
acclogp!!,
resetlogp!!,
Expand Down
93 changes: 0 additions & 93 deletions src/abstract_varinfo.jl
Original file line number Diff line number Diff line change
Expand Up @@ -553,99 +553,6 @@ variables `x` would return
"""
function tonamedtuple end

# TODO: Clean up all this linking stuff once and for all!
"""
with_logabsdet_jacobian_and_reconstruct([f, ]dist, x)

Like `Bijectors.with_logabsdet_jacobian(f, x)`, but also ensures the resulting
value is reconstructed to the correct type and shape according to `dist`.
"""
function with_logabsdet_jacobian_and_reconstruct(f, dist, x)
x_recon = reconstruct(f, dist, x)
return with_logabsdet_jacobian(f, x_recon)
end

# TODO: Once `(inv)link` isn't used heavily in `getindex(vi, vn)`, we can
# just use `first ∘ with_logabsdet_jacobian` to reduce the maintenance burden.
# NOTE: `reconstruct` is no-op if `val` is already of correct shape.
"""
reconstruct_and_link(dist, val)
reconstruct_and_link(vi::AbstractVarInfo, vi::VarName, dist, val)

Return linked `val` but reconstruct before linking, if necessary.

Note that unlike [`invlink_and_reconstruct`](@ref), this does not necessarily
return a reconstructed value, i.e. a value of the same type and shape as expected
by `dist`.

See also: [`invlink_and_reconstruct`](@ref), [`reconstruct`](@ref).
"""
reconstruct_and_link(f, dist, val) = f(reconstruct(f, dist, val))
reconstruct_and_link(dist, val) = reconstruct_and_link(link_transform(dist), dist, val)
function reconstruct_and_link(::AbstractVarInfo, ::VarName, dist, val)
return reconstruct_and_link(dist, val)
end

"""
invlink_and_reconstruct(dist, val)
invlink_and_reconstruct(vi::AbstractVarInfo, vn::VarName, dist, val)

Return invlinked and reconstructed `val`.

See also: [`reconstruct_and_link`](@ref), [`reconstruct`](@ref).
"""
invlink_and_reconstruct(f, dist, val) = f(reconstruct(f, dist, val))
function invlink_and_reconstruct(dist, val)
return invlink_and_reconstruct(invlink_transform(dist), dist, val)
end
function invlink_and_reconstruct(::AbstractVarInfo, ::VarName, dist, val)
return invlink_and_reconstruct(dist, val)
end

"""
maybe_link_and_reconstruct(vi::AbstractVarInfo, vn::VarName, dist, val)

Return reconstructed `val`, possibly linked if `istrans(vi, vn)` is `true`.
"""
function maybe_reconstruct_and_link(vi::AbstractVarInfo, vn::VarName, dist, val)
return if istrans(vi, vn)
reconstruct_and_link(vi, vn, dist, val)
else
reconstruct(dist, val)
end
end

"""
maybe_invlink_and_reconstruct(vi::AbstractVarInfo, vn::VarName, dist, val)

Return reconstructed `val`, possibly invlinked if `istrans(vi, vn)` is `true`.
"""
function maybe_invlink_and_reconstruct(vi::AbstractVarInfo, vn::VarName, dist, val)
return if istrans(vi, vn)
invlink_and_reconstruct(vi, vn, dist, val)
else
reconstruct(dist, val)
end
end

"""
invlink_with_logpdf(vi::AbstractVarInfo, vn::VarName, dist[, x])

Invlink `x` and compute the logpdf under `dist` including correction from
the invlink-transformation.

If `x` is not provided, `getval(vi, vn)` will be used.
"""
function invlink_with_logpdf(vi::AbstractVarInfo, vn::VarName, dist)
return invlink_with_logpdf(vi, vn, dist, getval(vi, vn))
end
function invlink_with_logpdf(vi::AbstractVarInfo, vn::VarName, dist, y)
# NOTE: Will this cause type-instabilities or will union-splitting save us?
f = istrans(vi, vn) ? invlink_transform(dist) : identity
x, logjac = with_logabsdet_jacobian_and_reconstruct(f, dist, y)
return x, logpdf(dist, x) + logjac
end

# Legacy code that is currently overloaded for the sake of simplicity.
# TODO: Remove when possible.
increment_num_produce!(::AbstractVarInfo) = nothing
Expand Down
30 changes: 9 additions & 21 deletions src/context_implementations.jl
Original file line number Diff line number Diff line change
Expand Up @@ -194,8 +194,8 @@ end

# fallback without sampler
function assume(dist::Distribution, vn::VarName, vi)
r, logp = invlink_with_logpdf(vi, vn, dist)
return r, logp, vi
r = vi[vn, dist]
return r, Bijectors.logpdf_with_trans(dist, r, istrans(vi, vn)), vi
end

# SampleFromPrior and SampleFromUniform
Expand All @@ -211,9 +211,7 @@ function assume(
if sampler isa SampleFromUniform || is_flagged(vi, vn, "del")
unset_flag!(vi, vn, "del")
r = init(rng, dist, sampler)
BangBang.setindex!!(
vi, vectorize(dist, maybe_reconstruct_and_link(vi, vn, dist, r)), vn
)
BangBang.setindex!!(vi, vectorize(dist, maybe_link(vi, vn, dist, r)), vn)
setorder!(vi, vn, get_num_produce(vi))
else
# Otherwise we just extract it.
Expand All @@ -222,17 +220,15 @@ function assume(
else
r = init(rng, dist, sampler)
if istrans(vi)
push!!(vi, vn, reconstruct_and_link(dist, r), dist, sampler)
push!!(vi, vn, link(dist, r), dist, sampler)
# By default `push!!` sets the transformed flag to `false`.
settrans!!(vi, true, vn)
else
push!!(vi, vn, r, dist, sampler)
end
end

# HACK: The above code might involve an `invlink` somewhere, etc. so we need to correct.
logjac = logabsdetjac(istrans(vi, vn) ? link_transform(dist) : identity, r)
return r, logpdf(dist, r) - logjac, vi
return r, Bijectors.logpdf_with_trans(dist, r, istrans(vi, vn)), vi
end

# default fallback (used e.g. by `SampleFromPrior` and `SampleUniform`)
Expand Down Expand Up @@ -474,11 +470,7 @@ function get_and_set_val!(
r = init(rng, dist, spl, n)
for i in 1:n
vn = vns[i]
setindex!!(
vi,
vectorize(dist, maybe_reconstruct_and_link(vi, vn, dist, r[:, i])),
vn,
)
setindex!!(vi, vectorize(dist, maybe_link(vi, vn, dist, r[:, i])), vn)
setorder!(vi, vn, get_num_produce(vi))
end
else
Expand Down Expand Up @@ -516,17 +508,13 @@ function get_and_set_val!(
for i in eachindex(vns)
vn = vns[i]
dist = dists isa AbstractArray ? dists[i] : dists
setindex!!(
vi, vectorize(dist, maybe_reconstruct_and_link(vi, vn, dist, r[i])), vn
)
setindex!!(vi, vectorize(dist, maybe_link(vi, vn, dist, r[i])), vn)
setorder!(vi, vn, get_num_produce(vi))
end
else
# r = reshape(vi[vec(vns)], size(vns))
# FIXME: Remove `reconstruct` in `getindex_raw(::VarInfo, ...)`
# and fix the lines below.
r_raw = getindex_raw(vi, vec(vns))
r = maybe_invlink_and_reconstruct.((vi,), vns, dists, reshape(r_raw, size(vns)))
r = maybe_invlink.((vi,), vns, dists, reshape(r_raw, size(vns)))
end
else
f = (vn, dist) -> init(rng, dist, spl)
Expand All @@ -537,7 +525,7 @@ function get_and_set_val!(
# 2. Define an anonymous function which returns `nothing`, which
# we then broadcast. This will allocate a vector of `nothing` though.
if istrans(vi)
push!!.((vi,), vns, reconstruct_and_link.((vi,), vns, dists, r), dists, (spl,))
push!!.((vi,), vns, link.((vi,), vns, dists, r), dists, (spl,))
# NOTE: Need to add the correction.
acclogp!!(vi, sum(logabsdetjac.(bijector.(dists), r)))
# `push!!` sets the trans-flag to `false` by default.
Expand Down
17 changes: 7 additions & 10 deletions src/simple_varinfo.jl
Original file line number Diff line number Diff line change
Expand Up @@ -290,7 +290,7 @@ end

# `NamedTuple`
function Base.getindex(vi::SimpleVarInfo, vn::VarName, dist::Distribution)
return maybe_invlink_and_reconstruct(vi, vn, dist, getindex(vi, vn))
return maybe_invlink(vi, vn, dist, getindex(vi, vn))
end
function Base.getindex(vi::SimpleVarInfo, vns::Vector{<:VarName}, dist::Distribution)
vals_linked = mapreduce(vcat, vns) do vn
Expand Down Expand Up @@ -329,9 +329,6 @@ function getindex_raw(vi::SimpleVarInfo, vns::Vector{<:VarName}, dist::Distribut
return reconstruct(dist, vals, length(vns))
end

# HACK: because `VarInfo` isn't ready to implement a proper `getindex_raw`.
getval(vi::SimpleVarInfo, vn::VarName) = getindex_raw(vi, vn)

Base.haskey(vi::SimpleVarInfo, vn::VarName) = hasvalue(vi.values, vn)

function BangBang.setindex!!(vi::SimpleVarInfo, val, vn::VarName)
Expand Down Expand Up @@ -429,7 +426,7 @@ function assume(
)
value = init(rng, dist, sampler)
# Transform if we're working in unconstrained space.
value_raw = maybe_reconstruct_and_link(vi, vn, dist, value)
value_raw = maybe_link(vi, vn, dist, value)
vi = BangBang.push!!(vi, vn, value_raw, dist, sampler)
return value, Bijectors.logpdf_with_trans(dist, value, istrans(vi, vn)), vi
end
Expand All @@ -447,9 +444,9 @@ function dot_assume(

# Transform if we're working in transformed space.
value_raw = if dists isa Distribution
maybe_reconstruct_and_link.((vi,), vns, (dists,), value)
maybe_link.((vi,), vns, (dists,), value)
else
maybe_reconstruct_and_link.((vi,), vns, dists, value)
maybe_link.((vi,), vns, dists, value)
end

# Update `vi`
Expand All @@ -476,7 +473,7 @@ function dot_assume(

# Update `vi`.
for (vn, val) in zip(vns, eachcol(value))
val_linked = maybe_reconstruct_and_link(vi, vn, dist, val)
val_linked = maybe_link(vi, vn, dist, val)
vi = BangBang.setindex!!(vi, val_linked, vn)
end

Expand All @@ -491,7 +488,7 @@ function tonamedtuple(vi::SimpleOrThreadSafeSimple{<:NamedTuple{names}}) where {
nt_vals = map(keys(vi)) do vn
val = vi[vn]
vns = collect(DynamicPPL.TestUtils.varname_leaves(vn, val))
vals = map(copy ∘ Base.Fix1(getindex, vi), vns)
vals = map(Base.Fix1(getindex, vi), vns)
(vals, map(string, vns))
end

Expand All @@ -504,7 +501,7 @@ function tonamedtuple(vi::SimpleOrThreadSafeSimple{<:Dict})
# Extract the leaf varnames and values.
val = vi[vn]
vns = collect(DynamicPPL.TestUtils.varname_leaves(vn, val))
vals = map(copy ∘ Base.Fix1(getindex, vi), vns)
vals = map(Base.Fix1(getindex, vi), vns)

# Determine the corresponding symbol.
sym = only(unique(map(getsym, vns)))
Expand Down
2 changes: 0 additions & 2 deletions src/threadsafe.jl
Original file line number Diff line number Diff line change
Expand Up @@ -178,5 +178,3 @@ end

istrans(vi::ThreadSafeVarInfo, vn::VarName) = istrans(vi.varinfo, vn)
istrans(vi::ThreadSafeVarInfo, vns::AbstractVector{<:VarName}) = istrans(vi.varinfo, vns)

getval(vi::ThreadSafeVarInfo, vn::VarName) = getval(vi.varinfo, vn)
6 changes: 3 additions & 3 deletions src/transforming.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ function tilde_assume(

# Only transform if `!isinverse` since `vi[vn, right]`
# already performs the inverse transformation if it's transformed.
r_transformed = isinverse ? r : link_transform(right)(r)
r_transformed = isinverse ? r : bijector(right)(r)
return r, lp, setindex!!(vi, r_transformed, vn)
end

Expand All @@ -27,7 +27,7 @@ function dot_tilde_assume(
vi,
) where {isinverse}
r = getindex.((vi,), vns, (dist,))
b = link_transform(dist)
b = bijector(dist)

is_trans_uniques = unique(istrans.((vi,), vns))
@assert length(is_trans_uniques) == 1 "DynamicTransformationContext only supports transforming all variables"
Expand Down Expand Up @@ -70,7 +70,7 @@ function dot_tilde_assume(
@assert !isinverse "Trying to invlink non-transformed variables"
end

b = link_transform(dist)
b = bijector(dist)
for (vn, ri) in zip(vns, eachcol(r))
# Only transform if `!isinverse` since `vi[vn, right]`
# already performs the inverse transformation if it's transformed.
Expand Down
47 changes: 1 addition & 46 deletions src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -177,39 +177,10 @@ function to_namedtuple_expr(syms, vals)
return :(NamedTuple{$names_expr}($vals_expr))
end

"""
link_transform(dist)

Return the constrained-to-unconstrained bijector for distribution `dist`.

By default, this is just `Bijectors.bijector(dist)`.

!!! warning
Note that currently this is not used by `Bijectors.logpdf_with_trans`,
hence that needs to be overloaded separately if the intention is
to change behavior of an existing distribution.
"""
link_transform(dist) = bijector(dist)

"""
invlink_transform(dist)

Return the unconstrained-to-constrained bijector for distribution `dist`.

By default, this is just `inverse(link_transform(dist))`.

!!! warning
Note that currently this is not used by `Bijectors.logpdf_with_trans`,
hence that needs to be overloaded separately if the intention is
to change behavior of an existing distribution.
"""
invlink_transform(dist) = inverse(link_transform(dist))

#####################################################
# Helper functions for vectorize/reconstruct values #
#####################################################

vectorize(d, r) = vec(r)
vectorize(d::UnivariateDistribution, r::Real) = [r]
vectorize(d::MultivariateDistribution, r::AbstractVector{<:Real}) = copy(r)
vectorize(d::MatrixDistribution, r::AbstractMatrix{<:Real}) = copy(vec(r))
Expand All @@ -220,23 +191,7 @@ vectorize(d::MatrixDistribution, r::AbstractMatrix{<:Real}) = copy(vec(r))
# otherwise we will have error for MatrixDistribution.
# Note this is not the case for MultivariateDistribution so I guess this might be lack of
# support for some types related to matrices (like PDMat).

"""
reconstruct([f, ]dist, val)

Reconstruct `val` so that it's compatible with `dist`.

If `f` is also provided, the reconstruct value will be
such that `f(reconstruct_val)` is compatible with `dist`.
"""
reconstruct(f, dist, val) = reconstruct(dist, val)

# No-op versions.
reconstruct(::UnivariateDistribution, val::Real) = val
reconstruct(::MultivariateDistribution, val::AbstractVector{<:Real}) = copy(val)
reconstruct(::MatrixDistribution, val::AbstractMatrix{<:Real}) = copy(val)
# TODO: Implement no-op `reconstruct` for general array variates.

reconstruct(d::UnivariateDistribution, val::Real) = val
reconstruct(d::Distribution, val::AbstractVector) = reconstruct(size(d), val)
reconstruct(::Tuple{}, val::AbstractVector) = val[1]
reconstruct(s::NTuple{1}, val::AbstractVector) = copy(val)
Expand Down
Loading