Simplify Turing.Variational and fix test error #1377
function AdvancedVI.update(d::DistributionsAD.TuringDiagMvNormal, μ, σ)
    return DistributionsAD.TuringDiagMvNormal(μ, σ)
end
function AdvancedVI.update(td::Bijectors.TransformedDistribution, θ...)
    return Bijectors.transformed(AdvancedVI.update(td.dist, θ...), td.transform)
end
function AdvancedVI.update(
    td::Bijectors.TransformedDistribution{<:DistributionsAD.TuringDiagMvNormal},
    θ::AbstractArray,
)
Could these functions be moved to AdvancedVI (maybe by using Requires)? They are not specific to Turing.
The initial idea was to not have these implemented in AVI because they are a particular choice of updating a particular distribution. In general there might be multiple different ways to "update" the parameters of a distribution, hence I didn't want to make that choice in AVI but instead leave it to the user (which in this case kind of is Turing).
It's not a nice approach though, but at the time it wasn't quite clear to me how to handle the problem in general. And it honestly still isn't 😕 AVI also supports a simple mapping from parameters to distribution, which is the "preferred" way.
But we could/should at least move the above impls for `TransformedDistribution` (i.e. not specific to `TuringDiagMvNormal`) to AVI. For the rest, I'm a bit uncertain.
I'm not sure since I don't know the internals of AdvancedVI. In general, one approach that makes use of multiple dispatch but still allows specializations by users without too much effort or unintended side effects would be something like
struct DefaultVIMethod end
update(method, args...; kwargs...) = update(DefaultVIMethod(), args...; kwargs...)
function update(::DefaultVIMethod, d::TransformedDistribution, theta...)
    # default implementation
    ...
end
in AdvancedVI. Internally one would always call `update` with an additional `method` argument (e.g., using a `DefaultVIMethod` singleton by default). Then users/Turing could define something like
struct MyVIMethod end
function update(::MyVIMethod, d::TransformedDistribution, theta...)
    # special implementation
    ...
end
and specify that they want to use `MyVIMethod` instead, which would then be used in the update step whenever a special implementation for `MyVIMethod` is available (otherwise the fallback for `DefaultVIMethod` would be used).
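A minimal, self-contained sketch of the dispatch-on-method idea described above. `ToyNormal`, `ToyTransformed`, and `AbstractVIMethod` are hypothetical stand-ins invented for this illustration; nothing here is AdvancedVI's actual API. The generic fallback guards against calling itself forever when no default implementation exists.

```julia
# Toy stand-ins for real distribution types (hypothetical, not from any package).
struct ToyNormal{T<:Real}
    μ::T
    σ::T
end

struct ToyTransformed{D,F}
    dist::D
    transform::F
end

abstract type AbstractVIMethod end
struct DefaultVIMethod <: AbstractVIMethod end
struct MyVIMethod <: AbstractVIMethod end

# Generic fallback: any non-default method defers to the default implementation.
# The guard avoids infinite recursion when no default implementation exists.
function update(m::AbstractVIMethod, d, θ...)
    m isa DefaultVIMethod && throw(ArgumentError("no default update for $(typeof(d))"))
    return update(DefaultVIMethod(), d, θ...)
end

# Default implementation for the toy normal: parameters are passed positionally.
update(::DefaultVIMethod, d::ToyNormal, μ, σ) = ToyNormal(μ, σ)

# Wrapped distributions: update the inner distribution with the same method, then rewrap.
function update(m::AbstractVIMethod, td::ToyTransformed, θ...)
    return ToyTransformed(update(m, td.dist, θ...), td.transform)
end

# A user-defined specialization: parameters arrive as one vector, σ on the log scale.
update(::MyVIMethod, d::ToyNormal, θ::AbstractVector) = ToyNormal(θ[1], exp(θ[2]))

custom = update(MyVIMethod(), ToyNormal(0.0, 1.0), [1.0, 0.0])  # uses the specialization
fallback = update(MyVIMethod(), ToyNormal(0.0, 1.0), 2.0, 3.0)  # falls back to the default
wrapped = update(MyVIMethod(), ToyTransformed(ToyNormal(0.0, 1.0), exp), [1.0, 0.0])
```

Because every specialization carries its own singleton type in the first argument, two packages can each define their own variant without overwriting one another's methods.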
This `update` method is really completely independent of VI; it's simply a function for taking a distribution and a set of parameters, and returning a new distribution of the same type but with the new parameters.
It's similar to the `reconstruct` method I think we had (or still have?) in Turing? I'd almost prefer to just remove it and make the user provide the mapping from parameters to distribution directly (this is already a possibility), but it can be difficult for a user to do properly + everything would end up hard-coded to specific distributions. Feels like there should be a better way of doing it though.
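For illustration, a user-supplied mapping from parameters to distribution might look like the sketch below. `ToyNormal` and `getq` are hypothetical names made up for this example, and the choice to store the scale on the log scale is an assumption. It also shows the downside mentioned above: the mapping is hard-coded to one distribution family.

```julia
# Toy distribution type (hypothetical stand-in, not from any package).
struct ToyNormal{T<:Real}
    μ::T
    σ::T
end

# User-provided mapping from an unconstrained parameter vector to a distribution.
# exp keeps the scale positive for any real-valued θ[2].
getq(θ::AbstractVector) = ToyNormal(θ[1], exp(θ[2]))

q = getq([0.5, 0.0])
```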
There is `DynamicPPL.reconstruct` IIRC.
Wouldn't the approach above work even if it's independent from AdvancedVI? The problem with just overloading is that, since there are multiple possible ways of reconstructing a distribution, different approaches (in possibly different packages) would lead to method redefinitions. With an additional argument for dispatch that could be avoided, and it would be easier to define and use different approaches. And it would still allow defining some default fallback, which would be more difficult to achieve if it were just some user-defined function.
An alternative would be to provide some default reconstructions and use them by default (as argument to some VI method), but allow a user to specify her own mappings as well. Then a user would have to make sure that the function she provides is defined for all desired use cases, by e.g. falling back on the provided defaults if needed.
I agree, it would solve the problem, but it would end up quite nasty imo. E.g.
abstract type AbstractDistUpdate end
struct DefaultDistUpdate <: AbstractDistUpdate end
update(d::Distribution, theta...) = update(DefaultDistUpdate(), d, theta...)
update(ut::DefaultDistUpdate, d::Distribution, theta...) = ...
and so on.
> Then a user would have to make sure that the function she provides is defined for all desired use cases, by e.g. falling back on the provided defaults if needed.
I actually had this at some point but it was something awful, e.g.
@generated function update(d::Distribution, θ...)
    return :($(esc(nameof(d)))(θ...))
end
Is that something you're thinking of?
> Is that something you're thinking of?
No, I was thinking more about how to simplify the user implementations; it seems your suggestion would rather help to define the default fallbacks (which a user shouldn't have to care about). Since you mentioned that it's already possible to provide custom reconstructions, I assumed it would be possible to call some VI-related method in the following way
vimethod(reconstruct, moreargs...; kwargs...)
where `reconstruct` is a function that defines the reconstructions. One could then define default reconstructions in AdvancedVI (or some other package) and use them by default, i.e., something like
# I'm not sure if it is helpful to use a generated function here
function reconstruct(d::Distribution, theta...)
    # maybe want to use https://github.com/SciML/DiffEqBase.jl/blob/8e1a627d10ec40a112460860cd13948cc0532c63/src/utils.jl#L73-L75 here
    return typeof(d)(theta...)
end
vimethod(moreargs...; kwargs...) = vimethod(reconstruct, moreargs...; kwargs...)
Then a user/package could define
# fallback (might not always be needed)
myreconstruct(d::Distribution, theta...) = AdvancedVI.reconstruct(d, theta...)
function myreconstruct(d::SomeDistribution, theta)
    # custom reconstruction
    ...
end
which would be used when calling `vimethod(myreconstruct, moreargs...; kwargs...)`. In this approach one does not have to define the singletons but needs to define the fallback explicitly (however, it might not always be needed).
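The function-argument approach can be sketched with toy types as follows. `ToyNormal`, `ToyScaled`, and `toy_vi` are hypothetical names invented for this illustration (`toy_vi` plays the role of `vimethod` and simply applies the reconstruction), and passing the parameters as a single collection is an assumption made to keep the call signatures unambiguous.

```julia
# Toy distribution types (hypothetical stand-ins, not from any package).
struct ToyNormal{T<:Real}
    μ::T
    σ::T
end

struct ToyScaled{T<:Real}
    scale::T
end

# Default reconstruction, as in the comment above: call the constructor of the same type.
reconstruct(d, θ) = typeof(d)(θ...)

# Stand-in for a VI entry point: it only applies the reconstruction to new parameters.
toy_vi(rec::Function, d, θ) = rec(d, θ)
toy_vi(d, θ) = toy_vi(reconstruct, d, θ)  # default reconstruction if none is given

# A user-defined reconstruction with an explicit fallback to the default.
myreconstruct(d, θ) = reconstruct(d, θ)
myreconstruct(d::ToyNormal, θ) = ToyNormal(θ[1], exp(θ[2]))  # σ on the log scale

a = toy_vi(ToyNormal(0.0, 1.0), [2.0, 3.0])                 # default: ToyNormal(2.0, 3.0)
b = toy_vi(myreconstruct, ToyNormal(0.0, 1.0), [2.0, 0.0])  # custom: σ = exp(0.0)
c = toy_vi(myreconstruct, ToyScaled(1.0), [4.0])            # custom function, default path
```

No singleton types are needed here, but the user-defined function has to spell out its own fallback.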
Ah, I see what you mean. It still ends up with the annoying "so I have to define specific `reconstruct` for each distribution?". I guess that is more of an issue with Distributions.jl rather than AVI. But you've convinced me; the above seems like it would be easier for packages to work with 👍
How about we merge this PR though, and I'll open an issue over at AVI referring to this discussion?
Btw, I forgot to mention that it's necessary to use `@generated` in `update` because it might not have the same type as the original distribution, e.g. when using AD.
Yeah, that's why I suggested to use https://github.com/SciML/DiffEqBase.jl/blob/8e1a627d10ec40a112460860cd13948cc0532c63/src/utils.jl#L73-L75
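To make the type issue concrete, here is a small sketch with a toy type. `ToyNormal`, `same_type`, and `parameterless` are hypothetical names for this illustration, and using `Base.typename(T).wrapper` to drop the type parameters is an assumption about one way to do it, not a reproduction of the linked helper. Integer versus floating-point parameters stand in for the plain-number versus AD-number mismatch.

```julia
# Toy distribution type (hypothetical stand-in, not from any package).
struct ToyNormal{T<:Real}
    μ::T
    σ::T
end

# Reconstruct with the exact concrete type, including its type parameters.
same_type(d, θ...) = typeof(d)(θ...)

# Reconstruct with the parameter-free wrapper type, so T is inferred from θ.
parameterless(d, θ...) = Base.typename(typeof(d)).wrapper(θ...)

d = ToyNormal(0, 1)                # a ToyNormal{Int}
q = parameterless(d, 0.5, 1.5)     # T is re-inferred as Float64

# The exact-type version tries to convert 0.5 to Int and throws.
err = try
    same_type(d, 0.5, 1.5)
    nothing
catch e
    e
end
```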
function AdvancedVI.vi(
    model::DynamicPPL.Model,
    alg::AdvancedVI.ADVI;
    optimizer = AdvancedVI.TruncatedADAGrad(),
)
Maybe this one and the one below as well?
Hmm, not quite clear to me why we want to move this to AVI. In my mind Turing depends on AVI for its VI functionality, and so Turing should be the one who overloads accordingly to ensure that the types used by Turing work nicely with AVI (which in this case is `DynamicPPL.Model`).
Yeah, I'm less certain about this one. It's just that neither the function nor any of its arguments belong to Turing, which is a bit strange (and type piracy 🏴‍☠️)
But isn't this more an artifact of Turing now mostly being just a "shell" / "interface" for a lot of upstream packages? Though I agree type-piracy is bad, it seems like we're inevitably going to have to do a bit of it in Turing unless we want to construct a ton of wrapper structs with no real additional functionality.
I agree Requires.jl is an alternative, but it also means that packages that are really upstream for Turing now need to be the ones maintaining compatibility rather than the other way around. I guess that's okay since we're the ones maintaining all the packages, but I feel like the significant use of Requires.jl will come back and bite us at some point. It's already kind of difficult to track down all the upstream packages which require change. In particular it does kind of break the "separation of concerns" as maintainers now need a fairly good knowledge of all the packages involved.
I guess my point is that the type-piracy vs. Requires.jl isn't a clear-cut "always go with Requires.jl"? Maaaybe that's just me though 🤷
No, I'm not convinced either, and in general I'm not a big fan of Requires.jl; it leads to all kinds of issues with compatibilities, multiple optional dependencies, package load times, and testing. IMO the Julia package ecosystem still has to come up with a better solution for this general problem.
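The trade-off between type piracy and wrapper structs discussed in this thread can be sketched with three toy modules. `Upstream`, `ModelLib`, and `Glue` are hypothetical stand-ins (loosely playing the roles of AdvancedVI, DynamicPPL, and Turing); none of the names below come from those packages.

```julia
module Upstream   # owns the function and the algorithm type
    struct Alg end
    fit(model, alg) = error("no fit method for $(typeof(model))")
end

module ModelLib   # owns the model type
    struct Model
        name::String
    end
end

module Glue       # owns neither `fit` nor `Model` nor `Alg`
    import ..Upstream, ..ModelLib

    # Type piracy: a method of a foreign function on exclusively foreign types.
    Upstream.fit(m::ModelLib.Model, ::Upstream.Alg) = "fitted $(m.name)"

    # Piracy-free alternative: Glue owns the wrapper, at the cost of an extra type
    # that adds no functionality of its own.
    struct WrappedModel
        model::ModelLib.Model
    end
    Upstream.fit(w::WrappedModel, ::Upstream.Alg) = "fitted $(w.model.name) via wrapper"
end

m = ModelLib.Model("demo")
pirated = Upstream.fit(m, Upstream.Alg())
wrapped = Upstream.fit(Glue.WrappedModel(m), Upstream.Alg())
```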
Codecov Report
@@            Coverage Diff             @@
##           master    #1377      +/-   ##
==========================================
- Coverage   69.67%   66.79%   -2.89%
==========================================
  Files          27       25       -2
  Lines        1540     1605      +65
==========================================
- Hits         1073     1072       -1
- Misses        467      533      +66
varinfo = Turing.VarInfo(model)
function Bijectors.bijector(
    model::DynamicPPL.Model,
    ::Val{sym2ranges} = Val(false),
I just noted that somehow I switched from keyword arguments to positional arguments (I guess since otherwise the name of the argument can't be omitted) - I guess we want to keep the keyword argument version?
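One way to keep both call styles is a keyword front end that forwards to the positional `Val` method. This is a sketch under assumptions: `ranges` and `ranges_kw` are hypothetical toy functions, and only the `sym2ranges` name is taken from the diff above.

```julia
# Positional version: the flag is a static parameter, so the argument needs no name.
ranges(x, ::Val{flag} = Val(false)) where {flag} = flag ? (x, :ranges) : x

# Keyword version: a keyword argument must be named, so it forwards the Val instead.
ranges_kw(x; sym2ranges::Val = Val(false)) = ranges(x, sym2ranges)

plain = ranges(1)
with_flag = ranges(1, Val(true))
with_kw = ranges_kw(1; sym2ranges = Val(true))
```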
Finally all tests are successful again 🎉 (Apart from the coverage...)
Really nice, many thanks @devmotion!
I'm okay with the PR as it is, and then I suggest we leave the changes to AVI we discussed for a later PR.
Awesome stuff @devmotion !
This PR removes some things that were left (but not used anymore) after #1362 and fixes a test error of Turing.Variational. Moreover, unnecessary package imports are removed. Switching from `using` to `import` points out more clearly to which packages used functions and types belong (`Turing` or its submodules are not needed at all).
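As a small illustration of the `import` style, here is a sketch using the `LinearAlgebra` standard library. The module `ImportDemo` and the type `Vec2` are hypothetical names invented for this example. With `import`, only the module name is brought into scope, so every use and every method extension has to name its owner.

```julia
module ImportDemo
    import LinearAlgebra   # brings only the module name into scope

    struct Vec2
        x::Float64
        y::Float64
    end

    # Extending the function requires the qualified name, which shows who owns it.
    LinearAlgebra.norm(v::Vec2) = hypot(v.x, v.y)

    # Calls are qualified as well.
    unit_length(v::Vec2) = LinearAlgebra.norm(v) == 1.0
end

len = ImportDemo.LinearAlgebra.norm(ImportDemo.Vec2(3.0, 4.0))
```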