# Simplify Turing.Variational and fix test error #1377
Changes from all commits:

```diff
@@ -1,18 +1,14 @@
-using StatsFuns
-using DistributionsAD
-using Bijectors
-using Bijectors: TransformedDistribution
-using Random: AbstractRNG, GLOBAL_RNG
-import Bijectors: bijector
-
 """
-    bijector(model::Model; sym_to_ranges = Val(false))
+    bijector(model::Model[, sym2ranges = Val(false)])
 
 Returns a `Stacked <: Bijector` which maps from the support of the posterior to ℝᵈ with `d`
 denoting the dimensionality of the latent variables.
 """
-function bijector(model::Model; sym_to_ranges::Val{sym2ranges} = Val(false)) where {sym2ranges}
-    varinfo = Turing.VarInfo(model)
+function Bijectors.bijector(
+    model::DynamicPPL.Model,
+    ::Val{sym2ranges} = Val(false),
+) where {sym2ranges}
+    varinfo = DynamicPPL.VarInfo(model)
     num_params = sum([size(varinfo.metadata[sym].vals, 1)
                       for sym ∈ keys(varinfo.metadata)])
```
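For orientation, here is a minimal sketch of how this entry point might be called after the PR. The `demo` model and its data are illustrative, not part of the PR:

```julia
using Turing, Bijectors

@model function demo(x)
    s ~ InverseGamma(2, 3)   # positive support
    m ~ Normal(0, sqrt(s))   # unbounded support
    x ~ Normal(m, sqrt(s))
end

model = demo(1.5)

# Stacked bijector mapping the posterior support to ℝᵈ.
b = bijector(model)

# With the new positional `Val` argument, also return the symbol → ranges lookup.
b, sym2ranges = bijector(model, Val(true))
```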
```diff
@@ -37,25 +33,27 @@ function bijector(model::Model; sym_to_ranges::Val{sym2ranges} = Val(false)) whe
         idx += varinfo.metadata[sym].ranges[end][end]
     end
 
-    bs = bijector.(tuple(dists...))
+    bs = Bijectors.bijector.(tuple(dists...))
 
     if sym2ranges
-        return Stacked(bs, ranges), (; collect(zip(keys(sym_lookup), values(sym_lookup)))...)
+        return (
+            Bijectors.Stacked(bs, ranges),
+            (; collect(zip(keys(sym_lookup), values(sym_lookup)))...),
+        )
     else
-        return Stacked(bs, ranges)
+        return Bijectors.Stacked(bs, ranges)
     end
 end
 
 """
-    meanfield(model::Model)
-    meanfield(rng::AbstractRNG, model::Model)
+    meanfield([rng, ]model::Model)
 
 Creates a mean-field approximation with multivariate normal as underlying distribution.
 """
-meanfield(model::Model) = meanfield(GLOBAL_RNG, model)
-function meanfield(rng::AbstractRNG, model::Model)
+meanfield(model::DynamicPPL.Model) = meanfield(Random.GLOBAL_RNG, model)
+function meanfield(rng::Random.AbstractRNG, model::DynamicPPL.Model)
     # setup
-    varinfo = Turing.VarInfo(model)
+    varinfo = DynamicPPL.VarInfo(model)
     num_params = sum([size(varinfo.metadata[sym].vals, 1)
                       for sym ∈ keys(varinfo.metadata)])
```
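A sketch of using `meanfield` with the hypothetical `demo` model above. The result is a `TransformedDistribution`, so draws already live in the constrained space; the qualified name assumes the function lives in `Turing.Variational`, as the surrounding file suggests:

```julia
using Random

q = Turing.Variational.meanfield(model)                    # MvNormal in ℝᵈ pushed through the inverse bijector
q_seeded = Turing.Variational.meanfield(MersenneTwister(42), model)  # reproducible initialization

z = rand(q)   # sample in the original (constrained) parameter space
```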
```diff
@@ -71,43 +69,58 @@ function meanfield(rng::AbstractRNG, model::Model)
             ranges[range_idx] = idx .+ r
             range_idx += 1
         end
 
         # append!(ranges, [idx .+ r for r ∈ varinfo.metadata[sym].ranges])
         idx += varinfo.metadata[sym].ranges[end][end]
     end
 
     # initial params
     μ = randn(rng, num_params)
-    σ = softplus.(randn(rng, num_params))
+    σ = StatsFuns.softplus.(randn(rng, num_params))
 
     # construct variational posterior
-    d = TuringDiagMvNormal(μ, σ)
-    bs = inv.(bijector.(tuple(dists...)))
-    b = Stacked(bs, ranges)
+    d = DistributionsAD.TuringDiagMvNormal(μ, σ)
+    bs = inv.(Bijectors.bijector.(tuple(dists...)))
+    b = Bijectors.Stacked(bs, ranges)
 
-    return transformed(d, b)
+    return Bijectors.transformed(d, b)
 end
 
 
 # Overloading stuff from `AdvancedVI` to specialize for Turing
-AdvancedVI.update(d::TuringDiagMvNormal, μ, σ) = TuringDiagMvNormal(μ, σ)
-AdvancedVI.update(td::TransformedDistribution, θ...) = transformed(AdvancedVI.update(td.dist, θ...), td.transform)
-function AdvancedVI.update(td::TransformedDistribution{<:TuringDiagMvNormal}, θ::AbstractArray)
+function AdvancedVI.update(d::DistributionsAD.TuringDiagMvNormal, μ, σ)
+    return DistributionsAD.TuringDiagMvNormal(μ, σ)
+end
+function AdvancedVI.update(td::Bijectors.TransformedDistribution, θ...)
+    return Bijectors.transformed(AdvancedVI.update(td.dist, θ...), td.transform)
+end
+function AdvancedVI.update(
+    td::Bijectors.TransformedDistribution{<:DistributionsAD.TuringDiagMvNormal},
+    θ::AbstractArray,
+)
```
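These overloads follow a simple contract: `update(d, θ...)` rebuilds a distribution of the same family from new parameter values. As a hedged illustration of the same pattern for a hypothetical user-defined family (`MyMeanField` is not part of any of these packages):

```julia
import AdvancedVI
import Distributions

# Hypothetical diagonal-Gaussian family, purely illustrative.
struct MyMeanField{T<:Real} <: Distributions.ContinuousMultivariateDistribution
    μ::Vector{T}
    σ::Vector{T}
end

# Rebuild the distribution from fresh parameters, mirroring the
# TuringDiagMvNormal overload in the diff above.
AdvancedVI.update(d::MyMeanField, μ, σ) = MyMeanField(μ, σ)
```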
**Comment on lines +90 to +99:**

> Could these functions be moved to AdvancedVI (maybe by using Requires)? They are not specific to Turing.

> The initial idea was to not have these implemented in AVI because they are a particular choice of updating a particular distribution. In general there might be multiple different ways to "update" the parameters of a distribution, hence I didn't want to make that choice in AVI but instead leave it to the user (which in this case kind of is Turing). It's not a nice approach though, but at the time it wasn't quite clear to me how to handle the problem in general. And it honestly still isn't 😕 AVI also supports a simple mapping from parameters to distribution, which is the "preferred" way. But we could/should at least move the above impls for […]

> I'm not sure, since I don't know the internals of AdvancedVI. In general, one approach that makes use of multiple dispatch but still allows specializations by users without too much effort or unintended side effects would be something like
>
> ```julia
> struct DefaultVIMethod end
>
> update(method, args...; kwargs...) = update(DefaultVIMethod(), args...; kwargs...)
>
> function update(::DefaultVIMethod, d::TransformedDistribution, theta...)
>     # default implementation
>     ...
> end
> ```
>
> in AdvancedVI. Internally one would always call […]
>
> ```julia
> struct MyVIMethod end
>
> function update(::MyVIMethod, d::TransformedDistribution, theta...)
>     # special implementation
>     ...
> end
> ```
>
> and specify that they want to use […]

> This […] It's similar to the […]

> There is […] Wouldn't the approach above work even if it's independent from AdvancedVI? The problem with just overloading is that, since there are multiple possible ways of reconstructing a distribution, different approaches (in possibly different packages) would lead to method redefinitions. With an additional argument for dispatch that could be avoided, and it would be easier to define and use different approaches. And it would still allow defining some default fallback, which would be more difficult to achieve if it were just some user-defined function. An alternative would be to provide some default reconstructions and use them by default (as an argument to some VI method), but allow a user to specify her own mappings as well. Then a user would have to make sure that the function she provides is defined for all desired use cases, e.g. by falling back on the provided defaults if needed.

> I agree, it would solve the problem, but it would end up quite nasty imo. E.g.
>
> ```julia
> abstract type AbstractDistUpdate end
> struct DefaultDistUpdate <: AbstractDistUpdate end
> update(d::Distribution, theta...) = update(DefaultDistUpdate(), d, theta...)
> update(ut::DefaultDistUpdate, d::Distribution, theta...) = ...
> ```
>
> and so on. I actually had this at some point but it was something awful, e.g.
>
> ```julia
> @generated function update(d::Distribution, θ...)
>     return :($(esc(nameof(d)))(θ...))
> end
> ```
>
> Is that something you're thinking of?

> No, I was thinking more about how to simplify the user implementations; it seems your suggestion would rather help to define the default fallbacks (which a user shouldn't have to care about). Since you mentioned that it's already possible to provide custom reconstructions, I assumed it would be possible to call some VI-related method in the following way: `vimethod(reconstruct, moreargs...; kwargs...)`, where
>
> ```julia
> # I'm not sure if it is helpful to use a generated function here
> function reconstruct(d::Distribution, theta...)
>     # maybe want to https://github.com/SciML/DiffEqBase.jl/blob/8e1a627d10ec40a112460860cd13948cc0532c63/src/utils.jl#L73-L75 here
>     return typeof(d)(theta...)
> end
>
> vimethod(moreargs...; kwargs...) = vimethod(reconstruct, moreargs...; kwargs...)
> ```
>
> Then a user/package could define
>
> ```julia
> # fallback (might not always be needed)
> myreconstruct(d::Distribution, theta...) = AdvancedVI.reconstruct(d, theta...)
>
> function myreconstruct(d::SomeDistribution, theta)
>     # custom reconstruction
>     ...
> end
> ```
>
> which would be used when calling […]

> Ah, I see what you mean. It still ends up with the annoying "so I have to define specific […]"

> How about we merge this PR though, and I'll open an issue over at AVI referring to this discussion?

> Btw, I forgot to mention that it's necessary to use […]

> Yeah, that's why I suggested to use https://github.com/SciML/DiffEqBase.jl/blob/8e1a627d10ec40a112460860cd13948cc0532c63/src/utils.jl#L73-L75
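To make the method-redefinition concern above concrete: with plain overloading, two packages that pick different reconstructions for the same type collide, whereas an extra dispatch argument keeps them apart. A sketch under that assumption (the package split and method bodies are hypothetical):

```julia
import AdvancedVI
using Distributions: Normal
using StatsFuns: softplus

# Plain overloading: the second definition silently replaces the first,
# so whichever package loads last "wins" for every user of AdvancedVI.update.
AdvancedVI.update(d::Normal, θ...) = Normal(θ[1], exp(θ[2]))       # "package A"
AdvancedVI.update(d::Normal, θ...) = Normal(θ[1], softplus(θ[2]))  # "package B", overwrites A

# With a dispatch argument, both reconstructions coexist:
struct UpdateA end
struct UpdateB end
AdvancedVI.update(::UpdateA, d::Normal, θ...) = Normal(θ[1], exp(θ[2]))
AdvancedVI.update(::UpdateB, d::Normal, θ...) = Normal(θ[1], softplus(θ[2]))
```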
```diff
     μ, ω = θ[1:length(td)], θ[length(td) + 1:end]
-    return AdvancedVI.update(td, μ, softplus.(ω))
+    return AdvancedVI.update(td, μ, StatsFuns.softplus.(ω))
 end
 
-function AdvancedVI.vi(model::Model, alg::ADVI; optimizer = TruncatedADAGrad())
+function AdvancedVI.vi(
+    model::DynamicPPL.Model,
+    alg::AdvancedVI.ADVI;
+    optimizer = AdvancedVI.TruncatedADAGrad(),
+)
```
**Comment on lines +104 to +108:**

> Maybe this one and the one below as well?

> Hmm, not quite clear to me why we want to move this to AVI. In my mind Turing depends on AVI for its VI functionality, and so Turing should be the one who overloads accordingly to ensure that the types used by Turing work nicely with AVI (which in this case is […]

> Yeah, I'm less certain about this one. It's just that neither the function nor any of its arguments belong to Turing, which is a bit strange (and type piracy 🏴☠️).

> But isn't this more an artifact of Turing now mostly being just a "shell" / "interface" for a lot of upstream packages? Though I agree type piracy is bad, it seems like we're inevitably going to have to do a bit of it in Turing unless we want to construct a ton of wrapper structs with no real additional functionality. I agree Requires.jl is an alternative, but it also means that packages that are really upstream of Turing now need to be the ones maintaining compatibility, rather than the other way around. I guess that's okay since we're the ones maintaining all the packages, but I feel like significant use of Requires.jl will come back to bite us at some point. It's already kind of difficult to track down all the upstream packages which require changes. In particular, it does kind of break the "separation of concerns", as maintainers now need a fairly good knowledge of all the packages involved. I guess my point is that type piracy vs. Requires.jl isn't a clear-cut "always go with Requires.jl"? Maaaybe that's just me though 🤷

> No, I'm not convinced either, and in general I'm not a big fan of Requires.jl: it leads to all kinds of issues with compatibilities, multiple optional dependencies, package load times, and testing. IMO the Julia package ecosystem still has to come up with a better solution for this general problem.
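For readers unfamiliar with the term: type piracy means extending a function you don't own for argument types you also don't own, which can silently change behaviour for unrelated code. A deliberately bad toy example (do not do this):

```julia
module PiracyDemo

# Neither `Base.length` nor `Vector{Int}` is owned by this module, yet we
# add a more specific method for their combination: every other package
# loaded in the same session now sees the new behaviour.
Base.length(::Vector{Int}) = 0

end
```

The `AdvancedVI.vi(::DynamicPPL.Model, ...)` overloads in this PR are piracy in exactly this sense: the function belongs to AdvancedVI and the model type to DynamicPPL, and Turing owns neither.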
```diff
     q = meanfield(model)
     return AdvancedVI.vi(model, alg, q; optimizer = optimizer)
 end
 
 
-function AdvancedVI.vi(model::Model, alg::ADVI, q::TransformedDistribution{<:TuringDiagMvNormal}; optimizer = TruncatedADAGrad())
+function AdvancedVI.vi(
+    model::DynamicPPL.Model,
+    alg::AdvancedVI.ADVI,
+    q::Bijectors.TransformedDistribution{<:DistributionsAD.TuringDiagMvNormal};
+    optimizer = AdvancedVI.TruncatedADAGrad(),
+)
     @debug "Optimizing ADVI..."
     # Initial parameters for mean-field approx
-    μ, σs = params(q)
-    θ = vcat(μ, invsoftplus.(σs))
+    μ, σs = StatsBase.params(q)
+    θ = vcat(μ, StatsFuns.invsoftplus.(σs))
 
     # Optimize
     AdvancedVI.optimize!(elbo, alg, q, make_logjoint(model), θ; optimizer = optimizer)
```
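End to end, the user-facing flow might look like the following sketch, reusing the hypothetical `demo` model from above. Note the flat-parameter convention visible in the diff: the optimizer works on `θ = vcat(μ, invsoftplus.(σs))`, and `update` maps back with `softplus` so the scales stay positive throughout. `ADVI(10, 1_000)` (10 Monte Carlo samples per gradient estimate, 1000 steps) follows the usage in the Turing VI tutorial:

```julia
using Turing
using Turing.Variational   # assuming `vi` and `ADVI` are available from here

advi = ADVI(10, 1_000)
q = vi(model, advi)        # builds the mean-field q internally, as in the diff

# Or seed the optimization with an explicit mean-field approximation:
q0 = Turing.Variational.meanfield(model)
q = vi(model, advi, q0)
```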
Two files were deleted in this PR; their contents are not shown.
**Comment:**

> I just noted that somehow I switched from keyword arguments to positional arguments (I guess because otherwise the name of the argument can't be omitted). I guess we want to keep the keyword-argument version?
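The difference the comment refers to, at the call site (a sketch based on the old and new signatures in the diff):

```julia
# Positional Val argument (this PR): the argument name can be omitted.
bijector(model, Val(true))

# Keyword version (pre-PR): the name must be spelled out.
bijector(model; sym_to_ranges = Val(true))
```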