-
Notifications
You must be signed in to change notification settings - Fork 35
Accumulators, stage 1 #885
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 10 commits
061acbe
1496868
324e623
bb59885
fc32398
cc5e581
b9c368b
1b8f555
ae9b1cd
4fb0bf4
e410f47
97788bd
7fe03ec
5ba3530
d49f7be
28bbf1c
a0ed665
be27636
e6453fe
c59400d
47033ce
8b841c9
3ee3989
13163f2
37dd6dd
d7013b6
40d4caa
ff5f2cb
c68f1bb
13da08a
d52feec
221e797
1dbcb2c
e1b70e0
3f195e5
68b974a
2b405d9
00cd304
6d1048d
4fef20f
905b874
557954a
6f702c9
f748775
5f4a532
31967fd
ad2f564
10b4f2f
d2b670d
8241d12
7b7a3e2
0b08237
2a4b874
c1e90f7
cb1c6c6
00ef0cf
14f4788
c4ee4ec
7ad9450
c5e2a6b
fb09acc
e324c9b
6b7b9f8
048178b
6437801
bf95169
efc7c53
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||||
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
@@ -91,36 +91,70 @@ function transformation end | |||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||
# Accumulation of log-probabilities. | ||||||||||||||||||||||||||||||||||
""" | ||||||||||||||||||||||||||||||||||
getlogp(vi::AbstractVarInfo) | ||||||||||||||||||||||||||||||||||
getlogjoint(vi::AbstractVarInfo) | ||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||
Return the log of the joint probability of the observed data and parameters sampled in | ||||||||||||||||||||||||||||||||||
`vi`. | ||||||||||||||||||||||||||||||||||
""" | ||||||||||||||||||||||||||||||||||
function getlogp end | ||||||||||||||||||||||||||||||||||
getlogjoint(vi::AbstractVarInfo) = getlogprior(vi) + getloglikelihood(vi) | ||||||||||||||||||||||||||||||||||
getlogp(vi::AbstractVarInfo) = getlogjoint(vi) | ||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||
function setaccs!! end | ||||||||||||||||||||||||||||||||||
function getaccs end | ||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||
getlogprior(vi::AbstractVarInfo) = getacc(vi, LogPrior).logp | ||||||||||||||||||||||||||||||||||
getloglikelihood(vi::AbstractVarInfo) = getacc(vi, LogLikelihood).logp | ||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||
function setacc!!(vi::AbstractVarInfo, acc::AbstractAccumulator) | ||||||||||||||||||||||||||||||||||
return setaccs!!(vi, setacc!!(getaccs(vi), acc)) | ||||||||||||||||||||||||||||||||||
end | ||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||
setlogprior!!(vi::AbstractVarInfo, logp) = setacc!!(vi, LogPrior(logp)) | ||||||||||||||||||||||||||||||||||
setloglikelihood!!(vi::AbstractVarInfo, logp) = setacc!!(vi, LogLikelihood(logp)) | ||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||
""" | ||||||||||||||||||||||||||||||||||
setlogp!!(vi::AbstractVarInfo, logp) | ||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||
Set the log of the joint probability of the observed data and parameters sampled in | ||||||||||||||||||||||||||||||||||
`vi` to `logp`, mutating if it makes sense. | ||||||||||||||||||||||||||||||||||
""" | ||||||||||||||||||||||||||||||||||
function setlogp!! end | ||||||||||||||||||||||||||||||||||
function setlogp!!(vi::AbstractVarInfo, logp) | ||||||||||||||||||||||||||||||||||
vi = setlogprior!!(vi, zero(logp)) | ||||||||||||||||||||||||||||||||||
vi = setloglikelihood!!(vi, logp) | ||||||||||||||||||||||||||||||||||
return vi | ||||||||||||||||||||||||||||||||||
end | ||||||||||||||||||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I was thinking about this the other day and thought I may as well post now. The There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. My hope was that we could deprecate them but provide the same functionality through the new functions, like above. It's a good question as to whether there are edge cases where they do not provide the same functionality. I think this is helped by the fact that PriorContext and LikelihoodContext won't exist, and hence one can't be running code where the expectation would be that There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Something like this is a case where setlogp is ill-defined: DynamicPPL.jl/src/test_utils/varinfo.jl Lines 47 to 62 in c7bdc3f
The logp here contains terms from both prior and likelihood, but after calling setlogp the prior would always be 0, which is inconsistent with the varinfo. Of course, we can fix this on our end - we would get and set logprior and loglikelihood manually, and we can grep the codebase to make sure that there are no other ill-defined calls to setlogp. We can't guarantee that other people will be similarly careful, though (and us or anyone being careful also doesn't guarantee that everything will be fixed correctly). There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. While looking for other uses of setlogp, I encountered this:
(For the record, I'd be quite happy with making all of these changes!) There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
It is inconsistent, but as long as the user only uses
Yeah, this sort of stuff will come up (and is coming up) in multiple places. Anything that explicitly uses PriorContext or LikelihoodContext would need to be changed to use LogPrior and LogLikelihood accumulators instead. I'm currently doing this for |
||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||
function getacc(vi::AbstractVarInfo, ::Type{AccType}) where {AccType} | ||||||||||||||||||||||||||||||||||
return getacc(getaccs(vi), AccType) | ||||||||||||||||||||||||||||||||||
end | ||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||
function accumulate_assume!!(vi::AbstractVarInfo, r, logjac, vn, right) | ||||||||||||||||||||||||||||||||||
return setaccs!!(vi, accumulate_assume!!(getaccs(vi), r, logjac, vn, right)) | ||||||||||||||||||||||||||||||||||
end | ||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||
function accumulate_observe!!(vi::AbstractVarInfo, left, right) | ||||||||||||||||||||||||||||||||||
return setaccs!!(vi, accumulate_observe!!(getaccs(vi), left, right)) | ||||||||||||||||||||||||||||||||||
end | ||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||
function acc!!(vi::AbstractVarInfo, ::Type{AccType}, args...) where {AccType} | ||||||||||||||||||||||||||||||||||
return setaccs!!(vi, acc!!(getaccs(vi), AccType, args...)) | ||||||||||||||||||||||||||||||||||
end | ||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||
function acclogprior!!(vi::AbstractVarInfo, logp) | ||||||||||||||||||||||||||||||||||
return acc!!(vi, LogPrior, logp) | ||||||||||||||||||||||||||||||||||
end | ||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||
function accloglikelihood!!(vi::AbstractVarInfo, logp) | ||||||||||||||||||||||||||||||||||
return acc!!(vi, LogLikelihood, logp) | ||||||||||||||||||||||||||||||||||
end | ||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||
""" | ||||||||||||||||||||||||||||||||||
acclogp!!([context::AbstractContext, ]vi::AbstractVarInfo, logp) | ||||||||||||||||||||||||||||||||||
acclogp!!(vi::AbstractVarInfo, logp) | ||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||
Add `logp` to the value of the log of the joint probability of the observed data and | ||||||||||||||||||||||||||||||||||
parameters sampled in `vi`, mutating if it makes sense. | ||||||||||||||||||||||||||||||||||
""" | ||||||||||||||||||||||||||||||||||
function acclogp!!(context::AbstractContext, vi::AbstractVarInfo, logp) | ||||||||||||||||||||||||||||||||||
return acclogp!!(NodeTrait(context), context, vi, logp) | ||||||||||||||||||||||||||||||||||
end | ||||||||||||||||||||||||||||||||||
function acclogp!!(::IsLeaf, context::AbstractContext, vi::AbstractVarInfo, logp) | ||||||||||||||||||||||||||||||||||
return acclogp!!(vi, logp) | ||||||||||||||||||||||||||||||||||
end | ||||||||||||||||||||||||||||||||||
function acclogp!!(::IsParent, context::AbstractContext, vi::AbstractVarInfo, logp) | ||||||||||||||||||||||||||||||||||
return acclogp!!(childcontext(context), vi, logp) | ||||||||||||||||||||||||||||||||||
end | ||||||||||||||||||||||||||||||||||
acclogp!!(vi::AbstractVarInfo, logp) = accloglikelihood!!(vi, logp) | ||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||
""" | ||||||||||||||||||||||||||||||||||
resetlogp!!(vi::AbstractVarInfo) | ||||||||||||||||||||||||||||||||||
|
@@ -247,11 +281,11 @@ julia> values_as(SimpleVarInfo(data), Vector) | |||||||||||||||||||||||||||||||||
2.0 | ||||||||||||||||||||||||||||||||||
``` | ||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||
`TypedVarInfo`: | ||||||||||||||||||||||||||||||||||
`VarInfo` with `NamedTuple` of `Metadata`: | ||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||
```jldoctest | ||||||||||||||||||||||||||||||||||
julia> # Just use an example model to construct the `VarInfo` because we're lazy. | ||||||||||||||||||||||||||||||||||
vi = VarInfo(DynamicPPL.TestUtils.demo_assume_dot_observe()); | ||||||||||||||||||||||||||||||||||
vi = DynamicPPL.typed_varinfo(DynamicPPL.TestUtils.demo_assume_dot_observe()); | ||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||
julia> vi[@varname(s)] = 1.0; vi[@varname(m)] = 2.0; | ||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||
|
@@ -273,11 +307,11 @@ julia> values_as(vi, Vector) | |||||||||||||||||||||||||||||||||
2.0 | ||||||||||||||||||||||||||||||||||
``` | ||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||
`UntypedVarInfo`: | ||||||||||||||||||||||||||||||||||
`VarInfo` with `Metadata`: | ||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||
```jldoctest | ||||||||||||||||||||||||||||||||||
julia> # Just use an example model to construct the `VarInfo` because we're lazy. | ||||||||||||||||||||||||||||||||||
vi = VarInfo(); DynamicPPL.TestUtils.demo_assume_dot_observe()(vi); | ||||||||||||||||||||||||||||||||||
vi = DynamicPPL.untyped_varinfo(DynamicPPL.TestUtils.demo_assume_dot_observe()); | ||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||
julia> vi[@varname(s)] = 1.0; vi[@varname(m)] = 2.0; | ||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||
|
@@ -725,7 +759,7 @@ end | |||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||
# Legacy code that is currently overloaded for the sake of simplicity. | ||||||||||||||||||||||||||||||||||
# TODO: Remove when possible. | ||||||||||||||||||||||||||||||||||
increment_num_produce!(::AbstractVarInfo) = nothing | ||||||||||||||||||||||||||||||||||
increment_num_produce!!(::AbstractVarInfo) = nothing | ||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||
""" | ||||||||||||||||||||||||||||||||||
from_internal_transform(varinfo::AbstractVarInfo, vn::VarName[, dist]) | ||||||||||||||||||||||||||||||||||
|
Uh oh!
There was an error while loading. Please reload this page.