Simplify Turing.Variational and fix test error #1377


Merged: 2 commits, Aug 18, 2020
2 changes: 0 additions & 2 deletions Project.toml
@@ -20,7 +20,6 @@ LogDensityProblems = "6fdf6af0-433a-55f7-b3ed-c6c6e0b8df7c"
MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d"
NamedArrays = "86f7a689-2022-50b4-a561-43c23ac3c673"
Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"
ProgressLogging = "33c8b6b6-d38a-422a-b730-caa89a2f386c"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
Requires = "ae029012-a4dd-5104-9daa-d747884805df"
@@ -46,7 +45,6 @@ Libtask = "0.4"
LogDensityProblems = "^0.9, 0.10"
MCMCChains = "4"
NamedArrays = "0.9"
ProgressLogging = "0.1"
Reexport = "0.2.0"
Requires = "0.5, 1.0"
SpecialFunctions = "0.7.2, 0.8, 0.9, 0.10"
2 changes: 2 additions & 0 deletions src/Turing.jl
@@ -15,12 +15,14 @@ using Libtask
@reexport using Distributions, MCMCChains, Libtask, AbstractMCMC, Bijectors
using Tracker: Tracker

import AdvancedVI
import DynamicPPL: getspace, NoDist, NamedDist

const PROGRESS = Ref(true)
function turnprogress(switch::Bool)
@info "[Turing]: progress logging is $(switch ? "enabled" : "disabled") globally"
PROGRESS[] = switch
AdvancedVI.turnprogress(switch)
end

# Random probability measures.
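For illustration, a minimal usage sketch of the global switch defined above; with this change, the same call also forwards the setting to AdvancedVI (assumes only that Turing is installed):

```julia
using Turing

# Disable progress logging globally for Turing; this now also calls
# AdvancedVI.turnprogress(false), so VI runs follow the same setting.
Turing.turnprogress(false)

# Re-enable it later if desired.
Turing.turnprogress(true)
```
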
45 changes: 19 additions & 26 deletions src/variational/VariationalInference.jl
@@ -1,23 +1,16 @@
module Variational

using ..Core, ..Utilities
using DocStringExtensions: TYPEDEF, TYPEDFIELDS
using Distributions, Bijectors, DynamicPPL
using LinearAlgebra
using ..Turing: PROGRESS, Turing
using DynamicPPL: Model, SampleFromPrior, SampleFromUniform
using Random: AbstractRNG
import AdvancedVI
import Bijectors
import DistributionsAD
import DynamicPPL
import StatsBase
import StatsFuns

using ForwardDiff
using Tracker

import ..Core: getchunksize, getADbackend

import AbstractMCMC
import ProgressLogging

using AdvancedVI
import Random

# Reexports
using AdvancedVI: vi, ADVI, ELBO, elbo, TruncatedADAGrad, DecayedADAGrad
export
vi,
ADVI,
@@ -34,38 +27,38 @@ use `DynamicPPL.MiniBatch` context to run the `Model` with a weight `num_total_o
## Notes
- For the sake of efficiency, the returned function closes over an instance of `VarInfo`. This means that you *might* run into some weird behaviour if you call this method sequentially using different types; if that's the case, just generate a new one for each type using `make_logjoint`.
"""
function make_logjoint(model::Model; weight = 1.0)
function make_logjoint(model::DynamicPPL.Model; weight = 1.0)
# setup
ctx = DynamicPPL.MiniBatchContext(
DynamicPPL.DefaultContext(),
weight
)
varinfo_init = Turing.VarInfo(model, ctx)
varinfo_init = DynamicPPL.VarInfo(model, ctx)

function logπ(z)
varinfo = VarInfo(varinfo_init, SampleFromUniform(), z)
varinfo = DynamicPPL.VarInfo(varinfo_init, DynamicPPL.SampleFromUniform(), z)
model(varinfo)

return getlogp(varinfo)
return DynamicPPL.getlogp(varinfo)
end

return logπ
end

function logjoint(model::Model, varinfo, z)
varinfo = VarInfo(varinfo, SampleFromUniform(), z)
function logjoint(model::DynamicPPL.Model, varinfo, z)
varinfo = DynamicPPL.VarInfo(varinfo, DynamicPPL.SampleFromUniform(), z)
model(varinfo)

return getlogp(varinfo)
return DynamicPPL.getlogp(varinfo)
end


# objectives
function (elbo::ELBO)(
rng::AbstractRNG,
alg::VariationalInference,
rng::Random.AbstractRNG,
alg::AdvancedVI.VariationalInference,
q,
model::Model,
model::DynamicPPL.Model,
num_samples;
weight = 1.0,
kwargs...
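To make the now fully qualified code paths above concrete, here is a hedged usage sketch of `make_logjoint`; the `@model` definition and data are illustrative and not part of this PR:

```julia
using Turing
using Turing: Variational

@model function gauss(x)
    μ ~ Normal(0, 1)
    x ~ Normal(μ, 1)
end

model = gauss(1.5)

# Build the log-joint closure; `weight` rescales the likelihood term,
# e.g. when working with minibatches via DynamicPPL.MiniBatchContext.
logπ = Variational.make_logjoint(model; weight = 1.0)

# Evaluate the log joint at a parameter vector (here a single value for μ).
logπ([0.0])
```
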
79 changes: 46 additions & 33 deletions src/variational/advi.jl
@@ -1,18 +1,14 @@
using StatsFuns
using DistributionsAD
using Bijectors
using Bijectors: TransformedDistribution
using Random: AbstractRNG, GLOBAL_RNG
import Bijectors: bijector

"""
bijector(model::Model; sym_to_ranges = Val(false))
bijector(model::Model[, sym2ranges = Val(false)])

Returns a `Stacked <: Bijector` which maps from the support of the posterior to ℝᵈ with `d`
denoting the dimensionality of the latent variables.
"""
function bijector(model::Model; sym_to_ranges::Val{sym2ranges} = Val(false)) where {sym2ranges}
varinfo = Turing.VarInfo(model)
function Bijectors.bijector(
model::DynamicPPL.Model,
::Val{sym2ranges} = Val(false),
devmotion (Member, Author) commented:

I just noticed that somehow I switched from keyword arguments to positional arguments (I guess because otherwise the argument name can't be omitted) - I guess we want to keep the keyword-argument version?

) where {sym2ranges}
varinfo = DynamicPPL.VarInfo(model)
num_params = sum([size(varinfo.metadata[sym].vals, 1)
for sym ∈ keys(varinfo.metadata)])

@@ -37,25 +33,27 @@ function bijector(model::Model; sym_to_ranges::Val{sym2ranges} = Val(false)) whe
idx += varinfo.metadata[sym].ranges[end][end]
end

bs = bijector.(tuple(dists...))
bs = Bijectors.bijector.(tuple(dists...))

if sym2ranges
return Stacked(bs, ranges), (; collect(zip(keys(sym_lookup), values(sym_lookup)))...)
return (
Bijectors.Stacked(bs, ranges),
(; collect(zip(keys(sym_lookup), values(sym_lookup)))...),
)
else
return Stacked(bs, ranges)
return Bijectors.Stacked(bs, ranges)
end
end

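As a hedged sketch of the (now positional) `sym2ranges` flag discussed in the comment above; the two-parameter model is illustrative only:

```julia
using Turing

@model function demo(x)
    μ ~ Normal(0, 1)
    σ ~ truncated(Normal(0, 1), 0, Inf)
    x ~ Normal(μ, σ)
end

model = demo(1.0)

# Stacked bijector mapping the posterior support to ℝᵈ.
b = bijector(model)

# With the positional flag, additionally return the symbol-to-ranges lookup.
b, sym2range = bijector(model, Val(true))
sym2range[:σ]   # range(s) of σ in the flattened parameter vector
```
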
"""
meanfield(model::Model)
meanfield(rng::AbstractRNG, model::Model)
meanfield([rng, ]model::Model)

Creates a mean-field approximation with multivariate normal as underlying distribution.
"""
meanfield(model::Model) = meanfield(GLOBAL_RNG, model)
function meanfield(rng::AbstractRNG, model::Model)
meanfield(model::DynamicPPL.Model) = meanfield(Random.GLOBAL_RNG, model)
function meanfield(rng::Random.AbstractRNG, model::DynamicPPL.Model)
# setup
varinfo = Turing.VarInfo(model)
varinfo = DynamicPPL.VarInfo(model)
num_params = sum([size(varinfo.metadata[sym].vals, 1)
for sym ∈ keys(varinfo.metadata)])

@@ -71,43 +69,58 @@ function meanfield(rng::AbstractRNG, model::Model)
ranges[range_idx] = idx .+ r
range_idx += 1
end

# append!(ranges, [idx .+ r for r ∈ varinfo.metadata[sym].ranges])
idx += varinfo.metadata[sym].ranges[end][end]
end

# initial params
μ = randn(rng, num_params)
σ = softplus.(randn(rng, num_params))
σ = StatsFuns.softplus.(randn(rng, num_params))

# construct variational posterior
d = TuringDiagMvNormal(μ, σ)
bs = inv.(bijector.(tuple(dists...)))
b = Stacked(bs, ranges)
d = DistributionsAD.TuringDiagMvNormal(μ, σ)
bs = inv.(Bijectors.bijector.(tuple(dists...)))
b = Bijectors.Stacked(bs, ranges)

return transformed(d, b)
return Bijectors.transformed(d, b)
end


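A short, hedged sketch of `meanfield` for the same kind of illustrative model; it returns a `TuringDiagMvNormal` pushed forward to the model's support by the stacked bijector:

```julia
using Turing
using Turing: Variational

@model function demo(x)
    μ ~ Normal(0, 1)
    σ ~ truncated(Normal(0, 1), 0, Inf)
    x ~ Normal(μ, σ)
end

# Mean-field Gaussian approximation in the unconstrained space,
# transformed back to the support of (μ, σ).
q = Variational.meanfield(demo(1.0))

rand(q)   # a draw with the same support as (μ, σ)
```
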
# Overloading stuff from `AdvancedVI` to specialize for Turing
AdvancedVI.update(d::TuringDiagMvNormal, μ, σ) = TuringDiagMvNormal(μ, σ)
AdvancedVI.update(td::TransformedDistribution, θ...) = transformed(AdvancedVI.update(td.dist, θ...), td.transform)
function AdvancedVI.update(td::TransformedDistribution{<:TuringDiagMvNormal}, θ::AbstractArray)
function AdvancedVI.update(d::DistributionsAD.TuringDiagMvNormal, μ, σ)
return DistributionsAD.TuringDiagMvNormal(μ, σ)
end
function AdvancedVI.update(td::Bijectors.TransformedDistribution, θ...)
return Bijectors.transformed(AdvancedVI.update(td.dist, θ...), td.transform)
end
function AdvancedVI.update(
td::Bijectors.TransformedDistribution{<:DistributionsAD.TuringDiagMvNormal},
θ::AbstractArray,
)
Comment on lines +90 to +99
devmotion (Member, Author) commented:

Could these functions be moved to AdvancedVI (maybe by using Requires)? They are not specific to Turing.

torfjelde (Member) replied:

The initial idea was to not have these implemented in AVI because they are a particular choice of updating a particular distribution. In general there might be multiple different ways to "update" the parameters of a distribution, hence I didn't want to make that choice in AVI but instead leave it to the user (which in this case kind of is Turing).

It's not a nice approach though, but at the time it wasn't quite clear to me how to handle the problem in general. And it honestly still isn't 😕 AVI also supports a simple mapping from parameters to distribution, which is the "preferred" way.

But we could/should at least move the above impls for TransformedDistribution (i.e. not specific to TuringDiagMvNormal) to AVI. For the rest, I'm a bit uncertain.

devmotion (Member, Author) replied:

I'm not sure since I don't know the internals of AdvancedVI. In general, one approach that makes use of multiple dispatch but still allows specializations by users without too much effort or unintended side effects would be something like

struct DefaultVIMethod end

update(args...; kwargs...) = update(DefaultVIMethod(), args...; kwargs...)

function update(::DefaultVIMethod, d::TransformedDistribution, theta...)
    # default implementation
    ...
end

in AdvancedVI. Internally one would always call update with an additional method argument (e.g., using a DefaultVIMethod singleton by default). Then users/Turing could define something like

struct MyVIMethod end

function update(::MyVIMethod, d::TransformedDistribution, theta...)
    # special implementation
    ...
end

and specify that they want to use MyVIMethod instead, which would then be used in the update step whenever a special implementation for MyVIMethod is available (otherwise the fallback for DefaultVIMethod would be used).

torfjelde (Member) replied:

This update method is really completely independent of VI; it's simply a function for taking a distribution and a set of parameters, and returning a new distribution of the same type but with the new parameters.

It's similar to the reconstruct method I think we had (or still have?) in Turing? I'd almost prefer to just remove it and make the user provide the mapping from parameters to distribution directly (this is already a possibility), but it can be difficult for a user to do properly + everything would end up hard-coded to specific distributions. Feels like there should be a better way of doing it though.

devmotion (Member, Author) replied on Aug 17, 2020:

There is DynamicPPL.reconstruct IIRC.

Wouldn't the approach above work even if it's independent of AdvancedVI? The problem with just overloading is that, since there are multiple possible ways of reconstructing a distribution, different approaches (in possibly different packages) would lead to method redefinitions. With an additional argument for dispatch that could be avoided, and it would be easier to define and use different approaches. It would also still allow defining some default fallback, which would be more difficult to achieve if it were just some user-defined function.

An alternative would be to provide some default reconstructions and use them by default (as argument to some VI method), but allow a user to specify her own mappings as well. Then a user would have to make sure that the function she provides is defined for all desired use cases, by e.g. falling back on the provided defaults if needed.

torfjelde (Member) replied:

I agree, it would solve the problem, but it would end up quite nasty imo. E.g.

abstract type AbstractDistUpdate end
struct DefaultDistUpdate <: AbstractDistUpdate end

update(d::Distribution, theta...) = update(DefaultDistUpdate(), d, theta...)
update(ut::DefaultDistUpdate, d::Distribution, theta...) = ...

and so on.

> Then a user would have to make sure that the function she provides is defined for all desired use cases, by e.g. falling back on the provided defaults if needed.

I actually had this at some point but it was something awful, e.g.

@generated function update(d::Distribution, θ...)
    return :($(esc(nameof(d)))(θ...))
end

Is that something you're thinking of?

devmotion (Member, Author) replied:

> Is that something you're thinking of?

No, I was thinking more about how to simplify the user implementations; it seems your suggestion would rather help to define the default fallbacks (which a user shouldn't have to care about). Since you mentioned that it's already possible to provide custom reconstructions, I assumed it would be possible to call some VI-related method in the following way

vimethod(reconstruct, moreargs...; kwargs...)

where reconstruct is a function that defines the reconstructions. One could then define default reconstructions in AdvancedVI (or some other package) and use them by default, i.e., something like

# I'm not sure if it is helpful to use a generated function here
function reconstruct(d::Distribution, theta...)
    # maybe want to https://github.com/SciML/DiffEqBase.jl/blob/8e1a627d10ec40a112460860cd13948cc0532c63/src/utils.jl#L73-L75 here
    return typeof(d)(theta...)
end

vimethod(moreargs...; kwargs...) = vimethod(reconstruct, moreargs...; kwargs...)

Then a user/package could define

# fallback (might not always be needed)
myreconstruct(d::Distribution, theta...) = AdvancedVI.reconstruct(d, theta...)

function myreconstruct(d::SomeDistribution, theta)
    # custom reconstruction
    ...
end

which would be used when calling vimethod(myreconstruct, moreargs...; kwargs...). In this approach one does not have to define the singletons but needs to define the fallback explicitly (however, it might not always be needed).

torfjelde (Member) replied:

Ah, I see what you mean. It still ends up with the annoying "so I have to define specific reconstruct for each distribution?". I guess that is more of an issue with Distributions.jl rather than AVI. But you've convinced me; the above seems like it would be easier for packages to work with 👍

How about we merge this PR though, and I'll open an issue over at AVI referring to this discussion?

torfjelde (Member) added:

Btw, I forgot to mention that it's necessary to use @generated in update because it might not have the same type as the original distribution, e.g. when using AD.


μ, ω = θ[1:length(td)], θ[length(td) + 1:end]
return AdvancedVI.update(td, μ, softplus.(ω))
return AdvancedVI.update(td, μ, StatsFuns.softplus.(ω))
end

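To illustrate the overloads above (and the parameter-to-distribution mapping debated in the thread), a hedged sketch; the tiny model is again purely illustrative:

```julia
using Turing
using Turing: Variational
import AdvancedVI

@model function demo(x)
    μ ~ Normal(0, 1)
    x ~ Normal(μ, 1)
end

# A TransformedDistribution wrapping a TuringDiagMvNormal.
q = Variational.meanfield(demo(1.0))

# New parameters as one flat vector [μ; ω]; the new scales are softplus.(ω).
θ = randn(2 * length(q))
q2 = AdvancedVI.update(q, θ)
```
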
function AdvancedVI.vi(model::Model, alg::ADVI; optimizer = TruncatedADAGrad())
function AdvancedVI.vi(
model::DynamicPPL.Model,
alg::AdvancedVI.ADVI;
optimizer = AdvancedVI.TruncatedADAGrad(),
)
Comment on lines +104 to +108
devmotion (Member, Author) commented:

Maybe this one and the one below as well?

torfjelde (Member) replied:

Hmm, it's not quite clear to me why we want to move this to AVI. In my mind Turing depends on AVI for its VI functionality, and so Turing should be the one that overloads accordingly to ensure that the types used by Turing work nicely with AVI (which in this case is DynamicPPL.Model).

devmotion (Member, Author) replied:

Yeah, I'm less certain about this one. It's just that neither the function nor any of its arguments belong to Turing, which is a bit strange (and type piracy 🏴‍☠️)

torfjelde (Member) replied on Aug 17, 2020:

But isn't this more an artifact of Turing now mostly being just a "shell" / "interface" for a lot of upstream packages? Though I agree type-piracy is bad, it seems like we're inevitably going to have to do a bit of it in Turing unless we want to construct a ton of wrapper structs with no real additional functionality.

I agree Requires.jl is an alternative, but it also means that packages that are really upstream of Turing now need to be the ones maintaining compatibility rather than the other way around. I guess that's okay since we're the ones maintaining all the packages, but I feel like the significant use of Requires.jl will come back and bite us at some point. It's already kind of difficult to track down all the upstream packages that require changes. In particular, it does kind of break the "separation of concerns", as maintainers now need a fairly good knowledge of all the packages involved.

I guess my point is that the type-piracy vs. Requires.jl isn't a clear-cut "always go with Requires.jl"? Maaaybe that's just me though 🤷

devmotion (Member, Author) replied:

No, I'm not convinced either, and in general I'm not a big fan of Requires.jl: it leads to all kinds of issues with compatibility, multiple optional dependencies, package load times, and testing. IMO the Julia package ecosystem still has to come up with a better solution for this general problem.

q = meanfield(model)
return AdvancedVI.vi(model, alg, q; optimizer = optimizer)
end


function AdvancedVI.vi(model::Model, alg::ADVI, q::TransformedDistribution{<:TuringDiagMvNormal}; optimizer = TruncatedADAGrad())
function AdvancedVI.vi(
model::DynamicPPL.Model,
alg::AdvancedVI.ADVI,
q::Bijectors.TransformedDistribution{<:DistributionsAD.TuringDiagMvNormal};
optimizer = AdvancedVI.TruncatedADAGrad(),
)
@debug "Optimizing ADVI..."
# Initial parameters for mean-field approx
μ, σs = params(q)
θ = vcat(μ, invsoftplus.(σs))
μ, σs = StatsBase.params(q)
θ = vcat(μ, StatsFuns.invsoftplus.(σs))

# Optimize
AdvancedVI.optimize!(elbo, alg, q, make_logjoint(model), θ; optimizer = optimizer)
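Putting the pieces together, a hedged end-to-end sketch of the `vi` entry points after this change; the model and the `ADVI(10, 1_000)` hyperparameters (samples per step, maximum iterations) are illustrative:

```julia
using Turing
using Turing.Variational

@model function gauss(x)
    μ ~ Normal(0, 1)
    for i in eachindex(x)
        x[i] ~ Normal(μ, 1)
    end
end

model = gauss(randn(100))

# Uses meanfield(model) as the initial q and TruncatedADAGrad by default.
q = vi(model, ADVI(10, 1_000))

rand(q, 5)   # draws from the variational approximation
```
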
21 changes: 0 additions & 21 deletions src/variational/objectives.jl

This file was deleted.

94 changes: 0 additions & 94 deletions src/variational/optimisers.jl

This file was deleted.

2 changes: 1 addition & 1 deletion test/variational/optimisers.jl
@@ -1,5 +1,5 @@
using Random, Test, LinearAlgebra, ForwardDiff
using Turing.Variational: TruncatedADAGrad, DecayedADAGrad, apply!
using AdvancedVI: TruncatedADAGrad, DecayedADAGrad, apply!

function test_opt(ADPack, opt)
θ = randn(10, 10)
Expand Down