The conversation that started in https://github.com/TuringLang/DynamicPPL.jl/pull/885/files/048178b7a8946d17fceace9b37c7e40846d50b51#r2069657078 resulted in some reflection on `unflatten` between me and @penelopeysm. The conclusion is that if an angel were to read our code in `unflatten`, even after the changes in #885, they would cover their eyes and cry silently, for that code is Wrong. It is Wrong because it couples the element type of the random variables given to us with the type of the log probs. These are philosophically distinct: random variables can take values in the set of sea birds, and the `logpdf` would still be a float. More practically, there's no reason why we shouldn't be able to have variables that are `Float64`s but accumulate log prob as `Float32` (casting when necessary, e.g. when `logpdf(dist, x)` returns a `Float64`).
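A minimal sketch of that decoupling, assuming a hypothetical accumulator type `LogProbAcc` and helper `add_logp!` (neither is DynamicPPL API): the variables stay `Float64`, while each `Float64` log-density term is cast to the accumulator's `Float32` on the way in. A hand-written standard-normal log density stands in for `logpdf(Normal(), x)` so the block needs no packages.

```julia
# Hypothetical log-prob accumulator (not DynamicPPL API): its number type is
# fixed on construction, independently of the variables' element type.
mutable struct LogProbAcc{T<:Real}
    logp::T
end

# Add one log-density term, casting it to the accumulator's own type.
function add_logp!(acc::LogProbAcc{T}, term::Real) where {T<:Real}
    acc.logp += convert(T, term)
    return acc
end

# Stand-in for logpdf(Normal(0, 1), x); returns a Float64 for Float64 input.
std_normal_logpdf(x::Float64) = -0.5 * x^2 - 0.5 * log(2π)

x = [0.3, -1.2, 2.0]      # random variables are Float64 ...
acc = LogProbAcc(0.0f0)   # ... but the log prob is accumulated as Float32
for xi in x
    add_logp!(acc, std_normal_logpdf(xi))
end
```

Here `acc.logp` ends up a `Float32` holding (up to `Float32` rounding) the same total as summing the `Float64` terms directly.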
The reason why we do the Wrong thing is AD trace/dual number types. When the element type of our variables is `Dual{Float64}`, our log probs need to become `Dual{Float64}` as well. We thought about this in various ways and tried to find arguments for doing something other than special-casing `Dual`, but in the end failed. `Dual` just is special.
Thus, what we should do is change the `convert_eltype` call in `unflatten` so that it is only done for a few special types, namely AD trace/dual number types. Every AD package that introduces its own `Number` or `Real` subtype that needs to be passed around like a float would then define a method of a dedicated opt-in function, marking its type as one for which log probs do need to be converted.
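A minimal sketch of the opt-in mechanism, with every name hypothetical: `ToyPPL.is_ad_number` stands for the trait function the PPL would own (default `false`), and `ToyAD.ToyDual` stands in for an AD package's number type such as `ForwardDiff.Dual`. The AD package adds one method to mark its type; every other `Real` subtype keeps the default, so its element type never leaks into the log-prob type.

```julia
module ToyPPL
# Hypothetical opt-in trait: `false` for every Real subtype by default,
# so unflatten would leave the log-prob type alone.
is_ad_number(::Type{<:Real}) = false
end

module ToyAD
import ..ToyPPL

# Toy stand-in for an AD dual number type such as ForwardDiff.Dual.
struct ToyDual{T<:Real} <: Real
    value::T
    partial::T
end

# The AD package opts in by extending the PPL's trait for its own type.
ToyPPL.is_ad_number(::Type{<:ToyDual}) = true
end
```

With this in place, `ToyPPL.is_ad_number(ToyAD.ToyDual{Float64})` is `true`, while `Float64`, `Int` and `Bool` all get `false`. The PPL never needs to know about the AD package; the dependency runs the other way.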
@penelopeysm, please expand if I didn't cover our full conclusion.
Indeed, just to sum up, I think this is the behaviour we want. Suppose we have an existing VarInfo where the type of logp is Tlogp. (Right now it's a logp field; after #885 is merged to main, this type will be in the LogPrior and LogLikelihood accumulators.)
Then we call unflatten(vi, x::AbstractVector{Tx}) where Tx <: Real. The question is, in the VarInfo returned by this function, what type should its logp field have? We think it should have this behaviour:

| `Tlogp` | `Tx` | resulting `typeof(logp)` |
|---|---|---|
| `Float64` | `Float64` | `Float64` |
| `Float32` | `Float64` | `Float32` |
| `Float64` | `Dual{Float64}` | `Dual{Float64}` |
| `Float32` | `Dual{Float64}` | `Dual{Float32}` |

Doing this would require recognising that the Dual part of Tx must be used to wrap Tlogp, but the underlying 'base' Float32/Float64 type (or Int, or Bool, or SeaBird) shouldn't override the existing Tlogp precision.
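The rule in the table can be sketched as follows. Everything here is hypothetical: `ToyDual` stands in for `ForwardDiff.Dual` (whose real parameters are `Dual{Tag,V,N}`, so a real rewrap would substitute the value type `V`), `is_ad_number` is the opt-in trait, and `rewrap` is an extra per-package method that swaps the base number type inside the AD wrapper for `Tlogp`.

```julia
# Toy stand-in for an AD dual number such as ForwardDiff.Dual (hypothetical).
struct ToyDual{T<:Real} <: Real
    value::T
    partial::T
end

# Hypothetical opt-in trait: only AD number types return true.
is_ad_number(::Type{<:Real}) = false
is_ad_number(::Type{<:ToyDual}) = true

# Hypothetical per-package method: keep the AD wrapper of Tx, but use
# Tlogp as the base number type inside it.
rewrap(::Type{<:ToyDual}, ::Type{Tlogp}) where {Tlogp<:Real} = ToyDual{Tlogp}

# Type that logp should have after unflatten(vi, x::AbstractVector{Tx}).
function logp_type(::Type{Tlogp}, ::Type{Tx}) where {Tlogp<:Real,Tx<:Real}
    return is_ad_number(Tx) ? rewrap(Tx, Tlogp) : Tlogp
end
```

Under these assumptions `logp_type(Float32, ToyDual{Float64})` gives `ToyDual{Float32}`, while `logp_type(Float64, Int)` and `logp_type(Float64, Bool)` give `Float64`, so the base type of `Tx` never overrides `Tlogp`. Nested duals (e.g. for second derivatives) would need `rewrap` to recurse into the inner type; this sketch does not handle that.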