
Logp type should be separate from variable type - this needs AD extensions #906


Open
mhauru opened this issue May 2, 2025 · 1 comment


mhauru commented May 2, 2025

The conversation that started in https://github.com/TuringLang/DynamicPPL.jl/pull/885/files/048178b7a8946d17fceace9b37c7e40846d50b51#r2069657078 resulted in some reflection on unflatten between me and @penelopeysm. The conclusion is that if an angel were to read our code in unflatten, even after the changes in #885, they would cover their eyes and cry silently, for that code is Wrong. It is Wrong because it couples the element type of the random variables given to us with the type of the log probs. These are philosophically distinct: random variables can take values in the set of sea birds, and the logpdf would still be a float. More practically, there's no reason why we shouldn't be able to have variables that are Float64s but accumulate log prob as Float32 (casting as necessary when logpdf(dist, x) returns a Float64).

The reason why we do the Wrong thing is AD trace/dual number types. When the element type of our variables is Dual{Float64}, our log probs need to become Dual{Float64} as well. We thought about this in various ways and tried to find arguments for doing something other than special-casing Dual, but in the end failed. Dual just is special.

Thus, what we should do is change the convert_eltype call in unflatten so that it is only done for a few special types, namely AD trace/dual number types. Every AD package that introduces its own Number or Real subtype that needs to be passed around like a float needs to define a method for this function, to mark its type as one of those for which log probs do need to be converted.
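To make the opt-in idea concrete, here is a minimal sketch of what such a marker function could look like. Everything in it is an assumption for illustration: `ToyDual` is a stand-in for a real AD number type, and `is_ad_number_type`, `lift_logp` and `logp_after_unflatten` are hypothetical names, not existing DynamicPPL functions.

```julia
# Hypothetical stand-in for an AD number type such as a dual number.
struct ToyDual{T<:Real} <: Real
    value::T
    partial::T
end

# Hypothetical marker function: by default, the element type of the
# variables has no influence on the type of the log prob.
is_ad_number_type(::Type) = false
# An AD package would opt in by adding a method for its own number type.
is_ad_number_type(::Type{<:ToyDual}) = true

# Hypothetical helper: wrap a plain log prob in the AD type with a zero
# partial, keeping the precision of the log prob itself.
lift_logp(logp::L, ::Type{<:ToyDual}) where {L<:Real} = ToyDual{L}(logp, zero(L))

# What unflatten could do with the existing log prob, given the new values x.
function logp_after_unflatten(logp::Real, x::AbstractVector{Tx}) where {Tx}
    is_ad_number_type(Tx) || return logp  # ordinary values: leave logp alone
    return lift_logp(logp, Tx)            # AD values: logp must carry derivatives too
end
```

With this, a Float32 log prob stays a Float32 when `x` holds Float64s, and only becomes a `ToyDual` when `x` holds `ToyDual`s.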

@penelopeysm, please expand if I didn't cover our full conclusion.


penelopeysm commented May 2, 2025

Indeed, just to sum up, I think this is the behaviour we want. Suppose we have an existing VarInfo where the type of logp is Tlogp. (Right now it's a logp field; after #885 is merged to main, this type will be in the LogPrior and LogLikelihood accumulators.)

Then we call unflatten(vi, x::AbstractVector{Tx}) where Tx <: Real. The question is, in the VarInfo returned by this function, what type should its logp field have? We think it should have this behaviour:

| Tlogp | Tx | resulting typeof(logp) |
| --- | --- | --- |
| Float64 | Float64 | Float64 |
| Float32 | Float64 | Float32 |
| Float64 | Dual{Float64} | Dual{Float64} |
| Float32 | Dual{Float64} | Dual{Float32} |

Doing this would require recognising that the Dual part of Tx must be used to wrap Tlogp, but the underlying 'base' Float32/Float64 type (or Int, or Bool, or SeaBird) shouldn't override the existing Tlogp precision.
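As a rough illustration of that rule, the sketch below computes the resulting log prob type purely at the type level. It is not DynamicPPL code: `ToyDual` and `SeaBird` are toy types defined here, and `resulting_logp_type` is a hypothetical name. A real implementation would need a method like the second one for each AD number type.

```julia
# Hypothetical stand-in for an AD number type such as a dual number.
struct ToyDual{T<:Real} <: Real
    value::T
    partial::T
end

# A non-numeric value type, to show that Tx need not be a float at all.
struct SeaBird end

# Default: the element type of the variables does not affect the logp type.
resulting_logp_type(::Type{Tlogp}, ::Type{Tx}) where {Tlogp<:Real,Tx} = Tlogp

# AD case: only the ToyDual wrapper is taken from Tx; the base precision
# still comes from the existing Tlogp.
# (An already-wrapped Tlogp is not handled in this sketch.)
resulting_logp_type(::Type{Tlogp}, ::Type{<:ToyDual}) where {Tlogp<:Real} = ToyDual{Tlogp}

# The four rows of the table, in the same (Tlogp, Tx) order:
resulting_logp_type(Float64, Float64)           # Float64
resulting_logp_type(Float32, Float64)           # Float32
resulting_logp_type(Float64, ToyDual{Float64})  # ToyDual{Float64}
resulting_logp_type(Float32, ToyDual{Float64})  # ToyDual{Float32}
```

Dispatching on the wrapper alone is what keeps Int, Bool or SeaBird values from overriding the precision of the existing log prob.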
