
VarInfo with custom accumulators #744

Open
mhauru opened this issue Dec 10, 2024 · 17 comments

@mhauru
Member

mhauru commented Dec 10, 2024

@yebai and I have been discussing the idea of replacing the current VarInfo type with something more general as a wrapper around VarNamedVector/Metadata. The motivation is that VarInfo currently stores

  • logp. This is ostensibly innocuous and straightforward, but it is actually sometimes the log prior, sometimes the log likelihood, and sometimes the log joint. Also, some samplers hijack it, so that even when you're sampling from the log joint, a sampler may actually use logp to store the log likelihood. This can cause mix-ups: I've had bugs in the new Gibbs implementation where I thought I had the likelihood but actually had the prior.
  • num_produce and order, which are only used by particle methods (see Decide the fate of VarInfo.num_produce #661 for previous discussion)
  • a VarNamedVector/Metadata with the variable values.

We would rather not have fields in VarInfo that are only used by specific samplers, and others that are used for different purposes at different times.

The solution we've been thinking of would be some sort of wrapper type, probably nested, that wraps a VarNamedVector and allows one to store any extra information needed. If a particle sampler needs num_produce and order, it'll implement its own wrapper type, and other samplers that only need e.g. the log joint would use a different wrapper type.

All of this could of course be stored in a context (because anything can be done with a context), but contexts are already too difficult to reason about and should, in my opinion, only be used when a simpler tool won't cut it. Thus we would like an interface where somewhere in the tilde pipeline, maybe everywhere we currently call acclogp, we would call some more generic store_custom_varinfo_data function, which each wrapper varinfo would then overload to store logprior/num_produce/whatever_they_want_to_store.

It's not obvious though whether such an interface would be powerful enough to implement all the things we want to use it for. So we should probably start by making a list of all the things we want to use it for. This would include at least

  • log prior/log likelihood/log joint. I would like to store these separately too, to avoid mixing them up.
  • num_produce/order
  • what else?

@yebai probably has more thoughts to share on this.

mhauru changed the title from "Nested VarInfo" to "VarInfo with custom accumulators" on Feb 20, 2025
@mhauru
Member Author

mhauru commented Feb 20, 2025

Fleshing this idea out a bit. I'm imagining something like this:

abstract type AbstractAccumulator end

struct LogPrior{T<:AbstractFloat} <: AbstractAccumulator
    logp::T
end

struct LogLikelihood{T<:AbstractFloat} <: AbstractAccumulator
    logp::T
end

struct NumProduce{T<:Integer} <: AbstractAccumulator
    num::T
end

struct Orders{T<:Integer} <: AbstractAccumulator
    orders::Dict{VarName,T}
end

struct VarInfo{Tmeta,Accs<:NTuple{N,AbstractAccumulator} where {N}} <: AbstractVarInfo
    metadata::Tmeta
    accs::Accs
end

Functions like getlogp(vi), getloglikelihood(vi), and getnumproduce(vi) would check whether an accumulator of the necessary type is in vi.accs. If not, they would error with something like "This VarInfo is not tracking the num produce variable". If yes, they would return the relevant value. Not sure what to do if there are multiple accumulators of the same type; maybe just ban that possibility in an inner constructor.
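A minimal, self-contained sketch of that lookup and of the duplicate ban, assuming the accumulators live in a plain tuple. AccVarInfo and getacc are hypothetical names used only for illustration; duplicates are detected by the accumulator's type name, so LogPrior{Float32} and LogPrior{Float64} would also clash.

```julia
abstract type AbstractAccumulator end

struct LogPrior{T<:AbstractFloat} <: AbstractAccumulator
    logp::T
end

struct NumProduce{T<:Integer} <: AbstractAccumulator
    num::T
end

# Hypothetical stand-in for the proposed VarInfo.
struct AccVarInfo{Tmeta,Accs<:Tuple}
    metadata::Tmeta
    accs::Accs
    function AccVarInfo(metadata::Tmeta, accs::Accs) where {Tmeta,Accs<:Tuple}
        all(acc -> acc isa AbstractAccumulator, accs) ||
            throw(ArgumentError("every accumulator must subtype AbstractAccumulator"))
        # Ban two accumulators with the same type name, ignoring type parameters.
        typenames = map(acc -> nameof(typeof(acc)), accs)
        allunique(typenames) ||
            throw(ArgumentError("duplicate accumulator types: $(typenames)"))
        return new{Tmeta,Accs}(metadata, accs)
    end
end

# Hypothetical generic getter: return the accumulator of type T, or error loudly.
function getacc(vi::AccVarInfo, ::Type{T}) where {T<:AbstractAccumulator}
    for acc in vi.accs
        acc isa T && return acc
    end
    error("This VarInfo is not tracking an accumulator of type $T")
end

# Sugar on top of getacc.
getlogprior(vi::AccVarInfo) = getacc(vi, LogPrior).logp
getnumproduce(vi::AccVarInfo) = getacc(vi, NumProduce).num
```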

At the end of tilde_observe!! and tilde_assume!! there would be something like

vi = @set vi.accs = map(acc -> accumulate_observe!!(acc, vi, left, right), vi.accs)
# or
vi = @set vi.accs = map(acc -> accumulate_assume!!(acc, vi, vn, right), vi.accs)

which is only allowed to modify acc, not the other arguments. In the same way that every context needs to define a method for tilde_observe and tilde_assume, every AbstractAccumulator would need to define methods for accumulate_observe!! and accumulate_assume!!, such as

accumulate_observe!!(acc::LogPrior, vi, left, right) = acc
accumulate_assume!!(acc::LogPrior, vi, vn, right) = LogPrior(acc.logp + logpdf(right, vi[vn]))

although we could provide the fallback defaults

accumulate_observe!!(acc::AbstractAccumulator, vi, left, right) = accumulate!!(acc, vi, left, right)
accumulate_assume!!(acc::AbstractAccumulator, vi, vn, right) = accumulate!!(acc, vi, vn, right)
accumulate!!(acc::AbstractAccumulator, vi, left, right) = acc

Separating logp into LogPrior and LogLikelihood is a somewhat orthogonal change, orthogonal to replacing vi.logp and vi.num_produce with vi.accs, but I don't see a reason to not accumulate them independently of each other and add them up in getlogp. setlogp!! and acclogp!! would no longer exist, you would always have to specify whether you are setting/adding to the log prior or log likelihood.
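To make the split concrete, here is a sketch in which the prior and likelihood live in separate accumulators and the log joint is only ever computed by adding them. addlogp, acclogprior!! and accloglikelihood!! are hypothetical names; in this sketch, adding to a term that isn't being tracked is a silent no-op, which is one possible design choice among several.

```julia
abstract type AbstractAccumulator end

struct LogPrior{T<:AbstractFloat} <: AbstractAccumulator
    logp::T
end

struct LogLikelihood{T<:AbstractFloat} <: AbstractAccumulator
    logp::T
end

# Return the accumulator of type T from a tuple of accumulators, or error.
function getacc(accs::Tuple, ::Type{T}) where {T<:AbstractAccumulator}
    for acc in accs
        acc isa T && return acc
    end
    error("not tracking an accumulator of type $T")
end

# Hypothetical replacements for acclogp!!: the caller must say which term is being added.
# Accumulators of other types pass through unchanged.
addlogp(accs::Tuple, ::Type{T}, x) where {T} =
    map(acc -> acc isa T ? T(acc.logp + x) : acc, accs)
acclogprior!!(accs::Tuple, x) = addlogp(accs, LogPrior, x)
accloglikelihood!!(accs::Tuple, x) = addlogp(accs, LogLikelihood, x)

getlogprior(accs::Tuple) = getacc(accs, LogPrior).logp
getloglikelihood(accs::Tuple) = getacc(accs, LogLikelihood).logp
# The log joint is derived on demand and never stored.
getlogjoint(accs::Tuple) = getlogprior(accs) + getloglikelihood(accs)
```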

Benefits of a design like this:

  • We could get rid of DefaultContext, LikelihoodContext, and PriorContext, and hence the whole notion of a Leaf context, which would simplify context_implementations.jl a lot. We might also be able to get rid of e.g. PointwiseLogdensityContext, PriorExtractorContext, or ValuesAsInModelContext, and replace them with accumulators as well.
  • If some particular use case of a model needs to keep track of something unusual, such as particle methods needing Order and NumProduce, that can be implemented outside of DPPL. We no longer need to hard code that stuff into VarInfo. The implementations are much lighter and less error-prone than a full-fledged context would be.

I do wonder whether the accs::NTuple{N,AbstractAccumulator} should actually not be a field of VarInfo, but rather a separate object that gets passed around in the tilde pipeline next to the varinfo. The two serve quite different purposes, and you could argue that a lot of functions of VarInfo should in fact reset the accumulators to maintain a consistent state: For instance, if you change the value of a variable, or subset a VarInfo, the accumulator values no longer match the values stored in the Metadata.

Ping @yebai and @penelopeysm for thoughts.

@penelopeysm
Member

Fully in favour of the idea, just a couple of incredibly minor technical details, which maybe shouldn't even be in this comment:

I do wonder whether the accs::NTuple{N,AbstractAccumulator} should actually not be a field of VarInfo, but rather a separate object that gets passed around in the tilde pipeline next to the varinfo.

Possibly, but I think there is some entanglement of state anyway (exactly the instances that you described), so you'd end up with the varinfo doing things to the accumulators whenever it sees something. In which case we may as well put them together, so that their behaviour is contained within one object.

struct VarInfo{Tmeta,Accs<:NTuple{N,AbstractAccumulator} where {N}} <: AbstractVarInfo

Maybe we could also define (but using a generated function for type stability):

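function accumulate_observe!!(accs::NTuple{N,AbstractAccumulator}, vi, left, right) where {N}
    # Collect each accumulator's (possibly new) value into a new tuple; a plain
    # for loop would discard the results of the !! calls and return nothing.
    return map(acc -> accumulate_observe!!(acc, vi, left, right), accs)
end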

and then at the end of tilde_{assume,observe}!! we can just write

vi = @set vi.accs = accumulate_{assume,observe}!!(vi.accs, vi, left, right)
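As a hedged alternative to a generated function, mapping over the tuple may already be enough: Julia unrolls map over small heterogeneous tuples, so the result type is usually inferred. The sketch below checks this with Test.@inferred; NumObserve is a hypothetical accumulator and UnitNormal a hypothetical stand-in for a Distributions.jl Normal(μ, 1).

```julia
using Test: @inferred

abstract type AbstractAccumulator end

struct LogLikelihood{T<:AbstractFloat} <: AbstractAccumulator
    logp::T
end

# Hypothetical accumulator that counts observe statements.
struct NumObserve{T<:Integer} <: AbstractAccumulator
    n::T
end

# Hypothetical stand-in for Normal(μ, 1) from Distributions.jl.
struct UnitNormal
    μ::Float64
end
logpdf(d::UnitNormal, x) = -0.5 * (x - d.μ)^2 - 0.5 * log(2π)

accumulate_observe!!(acc::LogLikelihood, vi, left, right) =
    LogLikelihood(acc.logp + logpdf(right, left))
accumulate_observe!!(acc::NumObserve, vi, left, right) = NumObserve(acc.n + 1)

# Tuple method: returns a new tuple holding each updated accumulator.
function accumulate_observe!!(accs::NTuple{N,AbstractAccumulator}, vi, left, right) where {N}
    return map(acc -> accumulate_observe!!(acc, vi, left, right), accs)
end
```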

we could provide the fallback defaults

I'd prefer not to, because it's not hard to implement the required methods, and I think it's easier to enforce the interface if it errors loudly when the interface isn't obeyed.

Functions like getlogp(vi), getloglikelihood(vi), and getnumproduce(vi) would check whether an accumulator of the necessary type is in vi.accs

Is it possible to do this also with a generated function?

@mhauru
Member Author

mhauru commented Feb 21, 2025

Is it possible to do this also with a generated function?

Just to have the check be done at compile time and produce efficient code? Yeah, I think so. Could also try doing it the simpler way (maybe with recursion) and see if constant folding compiles it away, before bringing out the generated functions.
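A sketch of the "simpler way" with recursion over the tuple, assuming each accumulator type appears at most once. _findacc is a hypothetical helper; because the isa check only involves types known at compile time, Julia can typically fold the branch away, which @inferred checks in the usage below.

```julia
using Test: @inferred

abstract type AbstractAccumulator end
struct LogPrior{T<:AbstractFloat} <: AbstractAccumulator
    logp::T
end
struct LogLikelihood{T<:AbstractFloat} <: AbstractAccumulator
    logp::T
end
struct NumProduce{T<:Integer} <: AbstractAccumulator
    num::T
end

# Base case: ran out of accumulators without finding one of type T.
_findacc(::Type{T}, ::Tuple{}) where {T} =
    error("This VarInfo is not tracking an accumulator of type $T")

# Recursive case: return the head if it matches, otherwise recurse on the tail.
# `first(accs) isa T` depends only on types, so the compiler can resolve the branch.
_findacc(::Type{T}, accs::Tuple) where {T} =
    first(accs) isa T ? first(accs) : _findacc(T, Base.tail(accs))
```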

I'd prefer not to, because it's not hard to implement the required methods, and I think it's easier to enforce the interface if it errors loudly when the interface isn't obeyed.

Wouldn't it be easier, if I have an accumulator that only does something at observe, to only have to define that? Or if I have an accumulator that does the same thing for both observe and assume, then only have to define accumulate!!?

I'm not sure what you mean by making it easier to enforce the interface, in that the default implementations would guarantee that the interface is followed if you just subtype AbstractAccumulator, unless you explicitly break it by e.g. defining accumulate_observe!!(::MyType, args...) that returns something that it shouldn't. With the default implementations in place, the docs wouldn't say you have to implement accumulate_observe!! and accumulate_assume!!, but rather that you may implement one or both, or accumulate!!.

@penelopeysm
Member

penelopeysm commented Feb 21, 2025

Wouldn't it be easier, if I have an accumulator that only does something at observe, to only have to define that?

It's just one extra line to define a no-op accumulate_assume.

Or if I have an accumulator that does the same thing for both observe and assume, then only have to define accumulate!!

You can still do that by defining your own accumulate!! (in fact, even better, give it a custom name, like accumulate_numproduce!!) and redirecting both acc_observe and acc_assume to that, for a total of two extra lines.
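A sketch of that pattern with a hypothetical NumTilde accumulator that counts every tilde statement. Since no fallback methods are defined anywhere, the hypothetical Forgetful accumulator, which only defines the assume method, raises a MethodError at an observe instead of silently doing nothing.

```julia
abstract type AbstractAccumulator end

# Hypothetical accumulator that counts every tilde statement, assume or observe.
struct NumTilde{T<:Integer} <: AbstractAccumulator
    n::T
end

# One shared implementation under a descriptive name...
accumulate_numtilde!!(acc::NumTilde) = NumTilde(acc.n + 1)
# ...and two explicit one-line redirections.
accumulate_assume!!(acc::NumTilde, vi, vn, right) = accumulate_numtilde!!(acc)
accumulate_observe!!(acc::NumTilde, vi, left, right) = accumulate_numtilde!!(acc)

# Hypothetical accumulator that defines the assume method but forgets observe.
struct Forgetful <: AbstractAccumulator end
accumulate_assume!!(acc::Forgetful, vi, vn, right) = acc
```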

I'm not sure what you mean by making it easier to enforce the interface

I wasn't clear about this, apologies. You are right, having default implementations means that the interface will be satisfied even if the user doesn't try to. In that sense it actually makes it easier to satisfy the interface, in the sense that it is easier to have some function that returns something of the right type.

However, having default methods makes it easier to inadvertently introduce bugs, because a user forgot to define the behaviour they wanted and now there's a silent default method that does the wrong thing. In that sense, it makes it harder to have the right function that returns something of the right type and value.

Basically, I want to force the user to think and be explicit about what they're doing.

Also, having defaults makes it harder to find the correct method that a function call is dispatching to. See e.g. the current tilde pipeline for an extreme example of this. (I understand why we do it, but it doesn't make it any simpler.)

@penelopeysm
Member

I'm not super hung up about it, though, if it's well documented (with a docs page explaining how to implement a new accumulator).

@torfjelde
Member

A few high-level initial thoughts / questions:

  1. Where in the tilde-pipeline would accumulate be called?
    1. Specifically, the purpose of PriorContext is that it short-circuits the tilde-pipeline to not even compute the log-likelihood. It's unclear to me how this would be done using the accumulator approach + if we could, it would lead to similar incompatibilities as we have now with varinfo, e.g. if I try to accumulate both Prior and Likelihood, both are short-circuiting the tilde-pipeline and nothing is done.
  2. Wrapping varinfo or related will be painful due to the number of functions you need to implement and test. Example: ThreadSafeVarInfo has bugged out sooo many times due to this.
  3. The only difference between varinfo and context right now is really just that:
    • varinfo is captured on return of the model, thus enabling making changes in an immutable way.
    • context is not captured, and so any changes must be made in-place.
  4. Does this change mean that contexts are not supposed to be mutated at all going forwards?
  5. I agree that the current contexts approach is quite general, and combining the contexts can be brittle. However, there's nothing stopping us from adding traits to the contexts (which has always been an ambition) which specify where in the context stack each should be, e.g. the GibbsContext should be a parent to a leaf. So I'm not sure this accumulator approach, which AFAIK can sometimes be wrapped and sometimes not, solves this / improves this; such validation checks would still have to be implemented here as well. This seems to be something you've also considered in #2424 👍

@mhauru
Copy link
Member Author

mhauru commented Feb 24, 2025

Where in the tilde-pipeline would accumulate be called?
Specifically, the purpose of PriorContext is that it short-circuits the tilde-pipeline to not even compute the log-likelihood. It's unclear to me how this would be done using the accumulator approach + if we could, it would lead to similar incompatibilities as we have now with varinfo, e.g. if I try to accumulate both Prior and Likelihood, both are short-circuiting the tilde-pipeline and nothing is done.

The log-likelihood/log-prior computation would happen in the accumulate_obssume!! function. ("Obssume" is a stand-in for either assume or observe.) Hence, if the varinfo e.g. carries a LogLikelihood accumulator but not a LogPrior accumulator, then in tilde_assume!! there would only be a call to accumulate_assume!!(::LogLikelihood, args...), which is a no-op, and the log prior would never be computed.
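A small sketch of why the log prior is never computed when only a LogLikelihood accumulator is tracked, assuming the accumulate functions receive the variable's value directly rather than vi. UnitNormal and its call-counting logpdf are hypothetical stand-ins for a Distributions.jl distribution, there only to show when logpdf actually runs.

```julia
abstract type AbstractAccumulator end
struct LogPrior{T<:AbstractFloat} <: AbstractAccumulator
    logp::T
end
struct LogLikelihood{T<:AbstractFloat} <: AbstractAccumulator
    logp::T
end

# Hypothetical stand-in for Normal(μ, 1) that counts how often logpdf is evaluated.
struct UnitNormal
    μ::Float64
end
const LOGPDF_CALLS = Ref(0)
function logpdf(d::UnitNormal, x)
    LOGPDF_CALLS[] += 1
    return -0.5 * (x - d.μ)^2 - 0.5 * log(2π)
end

# At an assume, only the prior accumulator evaluates logpdf; the likelihood one is a no-op.
accumulate_assume!!(acc::LogPrior, val, vn, right) = LogPrior(acc.logp + logpdf(right, val))
accumulate_assume!!(acc::LogLikelihood, val, vn, right) = acc

# Only the accumulators actually present in the tuple get called.
assume_all(accs::Tuple, val, vn, right) =
    map(acc -> accumulate_assume!!(acc, val, vn, right), accs)
```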

The place to call accumulate_obssume!! would be after sampling, if any, has been done, and the varinfo metadata has an up-to-date value for all variables. Pretty much the same place where acclogp_obssume!! is currently called. To avoid doing invlinking multiple times, I should probably modify the above signature for accumulate_obssume!! to take in the current value of the variable and logabsdetjac, so that e.g. tilde_assume!! would be something like

function tilde_assume!!(::EmptyContext, right, vn, vi)
    f = from_maybe_linked_internal_transform(vi, vn, right)
    r, logjac = with_logabsdet_jacobian(f, getindex_internal(vi, vn))
    vi = @set vi.accs = map(acc -> accumulate_assume!!(acc, r, logjac, vi, vn, right), vi.accs)
    return r, vi
end

where EmptyContext would be the only leaf context, the one that marks the end of the context stack. (Could call it DefaultContext, just not to be mixed with the current DefaultContext which should really be called JointContext.)
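A runnable version of the mock tilde_assume!! above, under simplifying assumptions: the context argument is dropped, MiniVarInfo is a hypothetical minimal varinfo holding one Float64 per variable, linked variables are positive and stored as their logarithm, and UnitExponential stands in for Exponential(1). Following the usual change-of-variables rule, the prior accumulator in this sketch adds log|det J| of the internal-to-model map, so its total is a density in the space the varinfo stores.

```julia
abstract type AbstractAccumulator end
struct LogPrior{T<:AbstractFloat} <: AbstractAccumulator
    logp::T
end

# Hypothetical stand-in for Exponential(1) from Distributions.jl.
struct UnitExponential end
logpdf(::UnitExponential, x) = x < 0 ? -Inf : -x

# Hypothetical minimal varinfo: one internal Float64 per variable, plus which are linked.
struct MiniVarInfo{A<:Tuple}
    values::Dict{Symbol,Float64}
    linked::Set{Symbol}
    accs::A
end

# Hypothetical helper: internal value -> model-space value, plus log|det J| of that map.
# A linked positive variable is stored as y = log(x), so x = exp(y) and log|dx/dy| = y.
function from_internal_with_logjac(vi::MiniVarInfo, vn::Symbol)
    y = vi.values[vn]
    return vn in vi.linked ? (exp(y), y) : (y, 0.0)
end

# The prior accumulator adds the log-Jacobian to the model-space log density.
accumulate_assume!!(acc::LogPrior, val, logjac, vn, right) =
    LogPrior(acc.logp + logpdf(right, val) + logjac)

function tilde_assume!!(right, vn, vi::MiniVarInfo)
    r, logjac = from_internal_with_logjac(vi, vn)
    # Each accumulator gets the value and log-Jacobian; the stored values are untouched.
    accs = map(acc -> accumulate_assume!!(acc, r, logjac, vn, right), vi.accs)
    return r, MiniVarInfo(vi.values, vi.linked, accs)
end
```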

Wrapping varinfo or related will be painful due to the number of functions you need to implement and test. Example: ThreadSafeVarInfo has bugged out sooo many ties times due to this.

Is your worry about the multitude of functions for getting data from various accumulators, like the equivalents of the current getlogp and get_num_produce? I'm hoping for VarInfo to really only have one function for "get stuff from the accumulators", call it getacc, for which one of the arguments is the type of the accumulator to get data from. Functions like getlogprior and getlogjoint would be sugared versions of getacc, where the sugar can be implemented on the level of AbstractVarInfo. Thus I don't see a reason to have many more functions to implement for e.g. ThreadSafeVarInfo, implementing getacc would be enough.
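A sketch of that layering, with a hypothetical AbstractVarInfoSketch in place of AbstractVarInfo: the sugar is written once against getacc, so a wrapper only has to forward getacc. PlainVI and WrappedVI are hypothetical; a real ThreadSafeVarInfo-style wrapper would presumably also need to combine per-thread state inside its getacc.

```julia
abstract type AbstractAccumulator end
struct LogPrior{T<:AbstractFloat} <: AbstractAccumulator
    logp::T
end
struct LogLikelihood{T<:AbstractFloat} <: AbstractAccumulator
    logp::T
end

# Hypothetical stand-in for AbstractVarInfo.
abstract type AbstractVarInfoSketch end

# Sugar written once, in terms of getacc only.
getlogprior(vi::AbstractVarInfoSketch) = getacc(vi, LogPrior).logp
getloglikelihood(vi::AbstractVarInfoSketch) = getacc(vi, LogLikelihood).logp
getlogjoint(vi::AbstractVarInfoSketch) = getlogprior(vi) + getloglikelihood(vi)

# A concrete varinfo stores the accumulators and implements getacc by iterating over them.
struct PlainVI{A<:Tuple} <: AbstractVarInfoSketch
    accs::A
end
function getacc(vi::PlainVI, ::Type{T}) where {T<:AbstractAccumulator}
    for acc in vi.accs
        acc isa T && return acc
    end
    error("This VarInfo is not tracking an accumulator of type $T")
end

# A wrapper only forwards getacc; all of the sugar above then works for it unchanged.
struct WrappedVI{V<:AbstractVarInfoSketch} <: AbstractVarInfoSketch
    inner::V
end
getacc(vi::WrappedVI, ::Type{T}) where {T<:AbstractAccumulator} = getacc(vi.inner, T)
```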

Only difference between varinfo and context right now is really just that: [...]

That's a good point, I hadn't thought about it that way. I would add though that there's also a semantic difference for someone reading the code, i.e. I expect the two to be used in very different ways.

Does this change mean that contexts are not supposed to be mutated at all going forwards?

I hadn't thought about this necessarily changing anything about how contexts can be used. There would just be fewer contexts that we would need to implement.

So I'm not sure this accumulator approach, which AFAIK can sometimes be wrapped and sometimes not

I don't understand what you mean here by "wrapped".

@torfjelde
Member

"Obssume" is a stand-in for either assume or observe.

Amazing term:)

The place to call accumulate_obssume!! would be after sampling, if any, has been done, and the varinfo metadata has an up-to-date value for all variables. Pretty much the same place where acclogp_obssume!! is currently called.

Gotcha, but then that does lead to different behavior from the current, say, LikelihoodContext, right? Since in the current LikelihoodContext, we don't even hit the logpdf call due to the short-circuit that occurs just before calling assume (the inner-most call of the tilde-pipeline). IIUC, here we would still hit assume as usual, causing a logpdf call, but then decide whether or not to include this later in the acc stage in tilde_assume!! (the entry-point for the tilde-pipeline). Is that correct?

Thus I don't see a reason to have many more functions to implement for e.g. ThreadSafeVarInfo, implementing getacc would be enough.

I think I misunderstood from the original description of this issue (which does reference the idea of using wrappers) 👍

Functions like getlogprior and getlogjoint would be sugared versions of getacc, where the sugar can be implemented on the level of AbstractVarInfo.

So these would just iterate over the accumulators and extract whatever they needed?

That's a good point, I hadn't thought about it that way. I would add though that there's also a semantic difference for someone reading the code, i.e. I expect the two to be used in very different ways.

I hadn't thought about this necessarily changing anything about how contexts can be used. There would just be fewer contexts that we would need to implement.

I get that 👍 Overall, I like the idea; I also think contexts are doing too many things, and it makes them difficult to read. My main worry is just whether things will cleanly fit into the different semantics of accumulators vs. contexts. It's not 100% clear to me that separating the two will lead to more readable code unless there is an accompanying set of "rules" which decide where things go. It might get a bit annoying for both maintainers and contributors if every functionality that would currently be implemented as a context has to go through a discussion of "is this an accumulator or a context?".

Again, I do quite like this pattern you're suggesting, and am thinking it might be the way to go. But I think it might be worth going through the different contexts and make it clear which will be replaced by accumulators and which won't, and from that maybe derive a set of rules to decide where things go?

@mhauru
Member Author

mhauru commented Feb 28, 2025

Gotcha, but then that does lead to different behavior from the current, say, LikelihoodContext, right? Since in the current LikelihoodContext, we don't even hit the logpdf call due to the short-circuit that occurs just before calling assume (the inner-most call of the tilde-pipeline). IIUC, here we would still hit assume as usual, causing a logpdf call, but then decide whether or not to include this later in the acc stage in tilde_assume!! (the entry-point for the tilde-pipeline). Is that correct?

I don't think so, because the call to logpdf would be within accumulate_obssume!!(::LogLikelihoodAccumulator, ...), and if you don't have a LogLikelihoodAccumulator in your VarInfo that method will never be hit. Coming back to the mock implementation of tilde_assume!! from above:

function tilde_assume!!(::EmptyContext, right, vn, vi)
    f = from_maybe_linked_internal_transform(vi, vn, right)
    r, logjac = with_logabsdet_jacobian(f, getindex_internal(vi, vn))
    vi = @set vi.accs = map(acc -> accumulate_assume!!(acc, r, logjac, vi, vn, right), vi.accs)
    return r, vi
end

that just gets the value of the variable and transforms it, any logpdf computations would be the responsibility of things in vi.accs.

Note that with the new accumulators in place, LikelihoodContext and much of the stuff in context_implementations.jl wouldn't exist any more. Not sure exactly what the nested function calls would look like, but if assume continued to exist as a function call, it at least wouldn't do any logpdf computations.

So these would just iterate over the accumulators and extract whatever they needed?

Yes. With iteration hopefully happening at compile time.

It's not 100% clear to me that separating the two will lead to more readable code unless there is an accompanying set of "rules" which decide where things go. It might get a bit annoying for both maintainers and contributors if every functionality that would currently be implemented as a context have to go through a discussion of "is this an accumulator or a context?".

I would say anything that can be an accumulator should be an accumulator, because accumulators are much simpler and less powerful. If you can't do it with an accumulator, that's when you need a context. This is assuming you use accumulators according to their intended interface, which is to only define methods for accumulate_assume!! and accumulate_observe!!, and those methods should only ever mutate the accumulator and not any of the other arguments. If you find yourself dispatching tilde_assume on accumulator types you are doing it wrong and should use a context.

mhauru self-assigned this Mar 3, 2025
@torfjelde
Member

torfjelde commented Mar 3, 2025

I don't think so, because the call to logpdf would be within accumulate_obssume!!(::LogLikelihoodAccumulator), and if you don't have a LogLikelihoodAccumulator in your VarInfo that method will never be hit. Coming back to the mock implementation of tilde_assume!! from above:

Hmm, I'm a bit confused here.

Currently we have the following "pipeline": tilde_assume!! -> tilde_assume -> … -> tilde_assume -> assume.

assume is the method that actually calls logpdf, not tilde_assume!!.

Above you're referring to changing tilde_assume!! (which is the entrypoint); so are you saying we should completely remove tilde_assume and assume?

@mhauru
Member Author

mhauru commented Mar 3, 2025

Above you're referring to changing tilde_assume!! (which is the entrypoint); so are you saying we should completely remove tilde_assume and assume?

Maybe, not sure. I would look at what remains of the pipeline after the current leaf contexts no longer exist and see how it can be simplified. The above code snippet was a "morally correct" mock implementation. But regardless of whether assume would still exist in some form or not, it would no longer call logpdf. The logpdf call would now happen in accumulate_obssume!! and nowhere else.

@penelopeysm
Member

penelopeysm commented Mar 5, 2025

Does vi need to be passed as an argument to the accumulate functions, or could we leave it out?

For example, in this

accumulate_observe!!(acc::LogPrior, vi, left, right) = acc
accumulate_assume!!(acc::LogPrior, vi, vn, right) = LogPrior(acc.logp + logpdf(right, vi[vn]))

could we instead do

accumulate_observe!!(acc::LogPrior, val, left, right) = acc
accumulate_assume!!(acc::LogPrior, val, vn, right) = LogPrior(acc.logp + logpdf(right, val))

?

(I have no qualms with passing other info to accumulate_... like the log-jacobian. I'm just suggesting to not pass the whole varinfo.)

The reason why I'm suggesting this is because accessing vi gives it a lot of power. For example, it allows the accumulate_... functions to check for the presence and value of any other accumulators inside vi, and to modify its behaviour accordingly.

In fact, this would mean that the accumulator will have the same degree of power as the contexts that it's replacing. One would then have to worry about edge cases where AccumulatorFoo behaves in one way, but behaves differently when paired with AccumulatorBar — which was the very problem that we were trying to find a solution to 😄

Put another way, the reason why contexts are confusing is because tilde_assume(..., vi, ctx, ...) can do literally anything because it has access to both vi as well as the whole context stack (which is contained inside ctx). If we have accumulate_...(..., vi, acc, ...) then the same is true because we have access to vi as well as all the other accumulators (which are now contained inside vi). Effectively, it would be a disguised form of (1a) here.

If we don't allow it to access vi, it is then clearly mandated that each accumulator acts independently of the others, and each accumulation step only depends on left, right, and the sampled val. And I think this is a big part of the simplification that we're trying to accomplish.
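One consequence of dropping vi from the signature, shown as a sketch: since each accumulator's update depends only on its own state, left and right, the order in which accumulators are stored cannot change any of their results. NumObserve, UnitNormal and run_observations are hypothetical names; UnitNormal stands in for a Distributions.jl Normal(μ, 1).

```julia
abstract type AbstractAccumulator end
struct LogLikelihood{T<:AbstractFloat} <: AbstractAccumulator
    logp::T
end
# Hypothetical accumulator that counts observe statements.
struct NumObserve{T<:Integer} <: AbstractAccumulator
    n::T
end

# Hypothetical stand-in for Normal(μ, 1) from Distributions.jl.
struct UnitNormal
    μ::Float64
end
logpdf(d::UnitNormal, x) = -0.5 * (x - d.μ)^2 - 0.5 * log(2π)

# Neither method can see vi or any other accumulator: only acc, left and right.
accumulate_observe!!(acc::LogLikelihood, left, right) =
    LogLikelihood(acc.logp + logpdf(right, left))
accumulate_observe!!(acc::NumObserve, left, right) = NumObserve(acc.n + 1)

observe_all(accs::Tuple, left, right) =
    map(acc -> accumulate_observe!!(acc, left, right), accs)

# Feed a fixed sequence of (left, μ) observations through a tuple of accumulators.
function run_observations(accs::Tuple)
    for (left, μ) in ((0.5, 0.0), (1.5, 1.0), (-0.2, 0.0))
        accs = observe_all(accs, left, UnitNormal(μ))
    end
    return accs
end
```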

Finally, notice that this separation would also very naturally lead to a rule for what should be an accumulator and what should be a context. If the special behaviour we're trying to accomplish is (1) independent of other special behaviour; and (2) depends only on left, right, and val, then it should be an accumulator. If not, then we need a context.

@mhauru
Member Author

mhauru commented Mar 6, 2025

I agree @penelopeysm. I can't now remember why I put vi as an arg in my sketch, but it would be preferable to not have it. I think I need to try an implementation without it and see if a need arises, and have a think about what it means for this whole project if one does arise.

@penelopeysm
Member

It struck me that we'll need to have a replacement for the @addlogprob! macro as well, which is quite widely used as far as I can tell.

@mhauru
Member Author

mhauru commented Mar 17, 2025

Yeah, that's one of the stickiest things I have in mind for this. I think we should have @addlogprior! and @addloglikelihood!, and then @addlogprob! should call one of those two, maybe with a deprecation warning. But I'm not sure which of the two. Probably likelihood?
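A hedged sketch of how the macros could look, with @addlogprob! forwarding to the likelihood (the tentative choice above) and emitting a deprecation warning. Everything here is hypothetical: real macros would act on the model's varinfo, for which the variable __accs__ is a stand-in, and acclogprior!!/accloglikelihood!! are illustrative names. Base.depwarn only prints when Julia runs with --depwarn=yes.

```julia
abstract type AbstractAccumulator end
struct LogPrior{T<:AbstractFloat} <: AbstractAccumulator
    logp::T
end
struct LogLikelihood{T<:AbstractFloat} <: AbstractAccumulator
    logp::T
end

# Hypothetical helpers: add x to the accumulator of type T, leave the others untouched.
addlogp(accs, ::Type{T}, x) where {T} = map(acc -> acc isa T ? T(acc.logp + x) : acc, accs)
acclogprior!!(accs, x) = addlogp(accs, LogPrior, x)
accloglikelihood!!(accs, x) = addlogp(accs, LogLikelihood, x)

# Each macro rebinds the caller's __accs__ variable (hence esc on the whole expression).
macro addlogprior!(ex)
    return esc(:(__accs__ = acclogprior!!(__accs__, $ex)))
end
macro addloglikelihood!(ex)
    return esc(:(__accs__ = accloglikelihood!!(__accs__, $ex)))
end
# The old macro forwards to the likelihood and warns about the deprecation.
macro addlogprob!(ex)
    return esc(quote
        Base.depwarn("@addlogprob! is deprecated, use @addloglikelihood! or @addlogprior!",
                     Symbol("@addlogprob!"))
        __accs__ = accloglikelihood!!(__accs__, $ex)
    end)
end

# Hypothetical model body threading the accumulators through __accs__.
function model_body(__accs__)
    @addlogprior! -0.5
    @addloglikelihood! -1.0
    @addlogprob! -2.0
    return __accs__
end
```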

@torfjelde
Member

Maybe, not sure. I would look at what remains of the pipeline after the current leaf contexts no longer exist and see how it can be simplified. The above code snippet was a "morally correct" mock implementation. But regardless of whether assume would still exist in some form or not, it would no longer call logpdf. The logpdf call would now happen in accumulate_obssume!! and nowhere else.

Hmm, I still don't think I fully get it 😕 I think it'll be clearer for me once there's an attempt / mock of implementation 👍

@yebai
Member

yebai commented Apr 10, 2025

A minor suggestion for terminology: InferenceState / InferenceAccumulator might be more precise and readable than Accumulators. However, Accumulator is also a lovely name.


4 participants