Simplify Turing.Variational and fix test error #1377


Merged · 2 commits · Aug 18, 2020

Conversation

devmotion (Member)

This PR removes some things that were left over (but no longer used) after #1362 and fixes a test error in Turing.Variational. Moreover, unnecessary package imports are removed. Switching from using to import points out more clearly which packages the used functions and types belong to (Turing and its submodules are not needed at all).
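To illustrate the using vs. import distinction mentioned above (a minimal standard-library sketch, not code from the PR):

```julia
# `using` brings exported names into scope unqualified,
# so the origin of a name is implicit at the call site:
using LinearAlgebra
n1 = norm([3.0, 4.0])                # where does `norm` come from?

# `import` requires qualifying names at the call site, making it
# explicit which package each used function or type belongs to:
import LinearAlgebra
n2 = LinearAlgebra.norm([3.0, 4.0])  # clearly belongs to LinearAlgebra
```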

Comment on lines +90 to +99
function AdvancedVI.update(d::DistributionsAD.TuringDiagMvNormal, μ, σ)
    return DistributionsAD.TuringDiagMvNormal(μ, σ)
end
function AdvancedVI.update(td::Bijectors.TransformedDistribution, θ...)
    return Bijectors.transformed(AdvancedVI.update(td.dist, θ...), td.transform)
end
function AdvancedVI.update(
    td::Bijectors.TransformedDistribution{<:DistributionsAD.TuringDiagMvNormal},
    θ::AbstractArray,
)
devmotion (Member Author)

Could these functions be moved to AdvancedVI (maybe by using Requires)? They are not specific to Turing.

Member

The initial idea was to not have these implemented in AVI because they are a particular choice of updating a particular distribution. In general there might be multiple different ways to "update" the parameters of a distribution, hence I didn't want to make that choice in AVI but instead leave it to the user (which in this case kind of is Turing).

It's not a nice approach, but at the time it wasn't quite clear to me how to handle the problem in general. And honestly it still isn't 😕 AVI also supports a simple mapping from parameters to distribution, which is the "preferred" way.

But we could/should at least move the above impls for TransformedDistribution (i.e. not specific to TuringDiagMvNormal) to AVI. For the rest, I'm a bit uncertain.

devmotion (Member Author)

I'm not sure since I don't know the internals of AdvancedVI. In general, one approach that makes use of multiple dispatch but still allows specializations by users without too much effort or unintended side effects would be something like

struct DefaultVIMethod end

update(method, args...; kwargs...) = update(DefaultVIMethod(), args...; kwargs...)

function update(::DefaultVIMethod, d::TransformedDistribution, theta...)
    # default implementation
    ...
end

in AdvancedVI. Internally one would always call update with an additional method argument (e.g., using a DefaultVIMethod singleton by default). Then users/Turing could define something like

struct MyVIMethod end

function update(::MyVIMethod, d::TransformedDistribution, theta...)
    # special implementation
    ...
end

and specify that they want to use MyVIMethod instead, which would then be used in the update step whenever a special implementation for MyVIMethod is available (otherwise the fallback for DefaultVIMethod would be used).

Member

This update method is really completely independent of VI; it's simply a function for taking a distribution and a set of parameters, and returning a new distribution of the same type but with the new parameters.

It's similar to the reconstruct method I think we had (or still have?) in Turing? I'd almost prefer to just remove it and make the user provide the mapping from parameters to distribution directly (this is already a possibility), but it can be difficult for a user to do properly + everything would end up hard-coded to specific distributions. Feels like there should be a better way of doing it though.

devmotion (Member Author), Aug 17, 2020

There is DynamicPPL.reconstruct IIRC.

Wouldn't the approach above work even if it's independent of AdvancedVI? The problem with just overloading is that, since there are multiple possible ways of reconstructing a distribution, different approaches (in possibly different packages) would lead to method redefinitions. With an additional argument for dispatch that could be avoided, and it would be easier to define and use different approaches. It would also still allow defining a default fallback, which would be harder to achieve if it were just some user-defined function.

An alternative would be to provide some default reconstructions and use them by default (as argument to some VI method), but allow a user to specify her own mappings as well. Then a user would have to make sure that the function she provides is defined for all desired use cases, by e.g. falling back on the provided defaults if needed.

Member

I agree, it would solve the problem, but it would end up quite nasty imo. E.g.

abstract type AbstractDistUpdate end
struct DefaultDistUpdate <: AbstractDistUpdate end

update(d::Distribution, theta...) = update(DefaultDistUpdate(), d, theta...)
update(ut::DefaultDistUpdate, d::Distribution, theta...) = ...

and so on.

> Then a user would have to make sure that the function she provides is defined for all desired use cases, by e.g. falling back on the provided defaults if needed.

I actually had this at some point but it was something awful, e.g.

@generated function update(d::Distribution, θ...)
    return :($(esc(nameof(d)))(θ...))
end

Is that something you're thinking of?

devmotion (Member Author)

> Is that something you're thinking of?

No, I was thinking more about how to simplify the user implementations; it seems your suggestion would rather help with defining the default fallbacks (which a user shouldn't have to care about). Since you mentioned that it's already possible to provide custom reconstructions, I assumed it would be possible to call some VI-related method in the following way

vimethod(reconstruct, moreargs...; kwargs...)

where reconstruct is a function that defines the reconstructions. One could then define default reconstructions in AdvancedVI (or some other package) and use them by default, i.e., something like

# I'm not sure if it is helpful to use a generated function here
function reconstruct(d::Distribution, theta...)
    # maybe want to do something like https://github.com/SciML/DiffEqBase.jl/blob/8e1a627d10ec40a112460860cd13948cc0532c63/src/utils.jl#L73-L75 here
    return typeof(d)(theta...)
end

vimethod(moreargs...; kwargs...) = vimethod(reconstruct, moreargs...; kwargs...)

Then a user/package could define

# fallback (might not always be needed)
myreconstruct(d::Distribution, theta...) = AdvancedVI.reconstruct(d, theta...)

function myreconstruct(d::SomeDistribution, theta)
    # custom reconstruction
    ...
end

which would be used when calling vimethod(myreconstruct, moreargs...; kwargs...). In this approach one does not have to define the singletons but needs to define the fallback explicitly (however, it might not always be needed).

Member

Ah, I see what you mean. It still ends up with the annoying "so I have to define specific reconstruct for each distribution?". I guess that is more of an issue with Distributions.jl rather than AVI. But you've convinced me; the above seems like it would be easier for packages to work with 👍

How about we merge this PR though, and I'll open an issue over at AVI referring to this discussion?

Member

Btw, I forgot to mention that it's necessary to use @generated in update because the updated distribution might not have the same type as the original one, e.g. when using AD.
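To illustrate the point (a sketch with a hypothetical Gaussian struct; it uses Base.typename to rebuild from the unparameterized constructor, which is what the @generated update in the discussion achieves by splicing in the constructor name):

```julia
# Hypothetical toy distribution (stands in for e.g. TuringDiagMvNormal):
struct Gaussian{T<:Real}
    μ::T
    σ::T
end

# Naive reuse of the concrete type: typeof(d) is Gaussian{Float64}, so
# Gaussian{Float64}(θ...) must convert θ to Float64 -- this errors for
# parameter types that cannot be converted, e.g. ForwardDiff.Dual under AD.

# Rebuilding from the unparameterized constructor re-infers the parameter
# type; the @generated `update` quoted above splices the constructor
# *name* into the generated code for the same reason:
update(d, θ...) = Base.typename(typeof(d)).wrapper(θ...)

d  = Gaussian(0.0, 1.0)        # Gaussian{Float64}
d2 = update(d, 1//2, 2//1)     # Gaussian{Rational{Int64}}, type re-inferred
```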


Comment on lines +104 to +108
function AdvancedVI.vi(
    model::DynamicPPL.Model,
    alg::AdvancedVI.ADVI;
    optimizer = AdvancedVI.TruncatedADAGrad(),
)
devmotion (Member Author)

Maybe this one and the one below as well?

Member

Hmm, it's not quite clear to me why we want to move this to AVI. In my mind Turing depends on AVI for its VI functionality, and so Turing should be the one who overloads accordingly to ensure that the types used by Turing (which in this case is DynamicPPL.Model) work nicely with AVI.

devmotion (Member Author)

Yeah, I'm less certain about this one. It's just that neither the function nor any of its arguments belong to Turing, which is a bit strange (and type piracy 🏴‍☠️)

torfjelde (Member), Aug 17, 2020

But isn't this more an artifact of Turing now mostly being just a "shell" / "interface" for a lot of upstream packages? Though I agree type-piracy is bad, it seems like we're inevitably going to have to do a bit of it in Turing unless we want to construct a ton of wrapper structs with no real additional functionality.

Though I agree Requires.jl is an alternative, but it also means that packages that are really upstream for Turing now need to be the ones maintaining compatibility rather than the other way around. I guess that's okay since we're the ones maintaining all the packages, but I feel like the significant use of Requires.jl will come back and bite us at some point. It's already kind of difficult to track down all the upstream packages which require change. In particular it does kind of break the "separation of concerns" as maintainers now need a fairly good knowledge of all the packages involved.

I guess my point is that the type-piracy vs. Requires.jl isn't a clear-cut "always go with Requires.jl"? Maaaybe that's just me though 🤷

devmotion (Member Author)

No, I'm not convinced either, and in general I'm not a big fan of Requires.jl; it leads to all kinds of issues with compatibility, multiple optional dependencies, package load times, and testing. IMO the Julia package ecosystem still has to come up with a better solution for this general problem.

codecov bot commented on Aug 17, 2020

Codecov Report

Merging #1377 into master will decrease coverage by 2.88%.
The diff coverage is 54.41%.


@@            Coverage Diff             @@
##           master    #1377      +/-   ##
==========================================
- Coverage   69.67%   66.79%   -2.89%     
==========================================
  Files          27       25       -2     
  Lines        1540     1605      +65     
==========================================
- Hits         1073     1072       -1     
- Misses        467      533      +66     
Impacted Files                                     Coverage Δ
src/contrib/inference/AdvancedSMCExtensions.jl       0.00% <0.00%> (ø)
src/contrib/inference/sghmc.jl                       0.00% <0.00%> (ø)
src/core/Core.jl                                   100.00% <ø> (ø)
src/inference/mh.jl                                 84.37% <ø> (+1.44%) ⬆️
src/variational/VariationalInference.jl             57.14% <37.50%> (-32.22%) ⬇️
src/inference/hmc.jl                                81.65% <70.00%> (-3.87%) ⬇️
src/variational/advi.jl                             60.00% <77.27%> (-3.24%) ⬇️
src/Turing.jl                                      100.00% <100.00%> (ø)
src/core/ad.jl                                      73.84% <100.00%> (+0.16%) ⬆️
src/inference/Inference.jl                          86.30% <100.00%> (-5.11%) ⬇️
... and 21 more

Legend: Δ = absolute <relative> (impact), ø = not affected, ? = missing data
Last update 5fe3648...3da1be9.

devmotion mentioned this pull request on Aug 17, 2020
varinfo = Turing.VarInfo(model)
function Bijectors.bijector(
    model::DynamicPPL.Model,
    ::Val{sym2ranges} = Val(false),
devmotion (Member Author)

I just noticed that somehow I switched from keyword arguments to positional arguments (I guess because otherwise the argument name can't be omitted). I guess we want to keep the keyword argument version?

devmotion (Member Author)

Finally all tests are successful again 🎉 (Apart from the coverage...)

yebai (Member) left a comment

Really nice, many thanks @devmotion!

devmotion requested a review from torfjelde on August 18, 2020 at 09:38
torfjelde (Member) left a comment

I'm okay with the PR as it is, and then I suggest we leave the changes to AVI we discussed for a later PR.

Awesome stuff @devmotion !


3 participants