Factor out intermediate storage #72

Merged (1 commit, Mar 13, 2025)
58 changes: 24 additions & 34 deletions GeneralisedFilters/src/GeneralisedFilters.jl
@@ -12,23 +12,16 @@ using StatsBase
using CUDA
using NNlib

# Filtering utilities
include("callbacks.jl")
include("containers.jl")
include("resamplers.jl")

## FILTERING BASE ##########################################################################

abstract type AbstractFilter <: AbstractSampler end
abstract type AbstractBatchFilter <: AbstractFilter end

"""
instantiate(model, alg, initial; kwargs...)

Create an intermediate storage object to store the proposed/filtered states at each step.
"""
function instantiate end

# Default method
function instantiate(model, alg, initial; kwargs...)
return Intermediate(initial, initial)
end

"""
initialise([rng,] model, alg; kwargs...)

@@ -37,9 +30,9 @@ Propose an initial state distribution.
function initialise end

"""
step([rng,] model, alg, iter, intermediate, observation; kwargs...)
step([rng,] model, alg, iter, state, observation; kwargs...)

Perform a combined predict and update call of the filtering on the intermediate storage.
Perform a combined predict and update call of the filtering on the state.
"""
function step end

@@ -70,24 +63,22 @@ function filter(
model::AbstractStateSpaceModel,
alg::AbstractFilter,
observations::AbstractVector;
callback=nothing,
callback::Union{AbstractCallback,Nothing}=nothing,
kwargs...,
)
initial = initialise(rng, model, alg; kwargs...)
intermediate = instantiate(model, alg, initial; kwargs...)
isnothing(callback) || callback(model, alg, intermediate, observations; kwargs...)
state = initialise(rng, model, alg; kwargs...)
isnothing(callback) || callback(model, alg, state, observations, PostInit; kwargs...)

log_evidence = initialise_log_evidence(alg, model)

for t in eachindex(observations)
intermediate, ll_increment = step(
rng, model, alg, t, intermediate, observations[t]; callback, kwargs...
state, ll_increment = step(
rng, model, alg, t, state, observations[t]; callback, kwargs...
)
log_evidence += ll_increment
isnothing(callback) ||
callback(model, alg, t, intermediate, observations; kwargs...)
end

return intermediate.filtered, log_evidence
return state, log_evidence
end

function initialise_log_evidence(::AbstractFilter, model::AbstractStateSpaceModel)
@@ -112,16 +103,20 @@ function step(
model::AbstractStateSpaceModel,
alg::AbstractFilter,
iter::Integer,
intermediate,
state,
observation;
callback::Union{AbstractCallback,Nothing}=nothing,
kwargs...,
)
intermediate.proposed = predict(rng, model, alg, iter, intermediate.filtered; kwargs...)
intermediate.filtered, ll_increment = update(
model, alg, iter, intermediate.proposed, observation; kwargs...
)
state = predict(rng, model, alg, iter, state; kwargs...)
isnothing(callback) ||
callback(model, alg, iter, state, observation, PostPredict; kwargs...)

return intermediate, ll_increment
state, ll_increment = update(model, alg, iter, state, observation; kwargs...)
isnothing(callback) ||
callback(model, alg, iter, state, observation, PostUpdate; kwargs...)

return state, ll_increment
end

## SMOOTHING BASE ##########################################################################
@@ -131,11 +126,6 @@ abstract type AbstractSmoother <: AbstractSampler end
# function smooth end
# function backward end

# Filtering utilities
include("callbacks.jl")
include("containers.jl")
include("resamplers.jl")

# Model types
include("models/linear_gaussian.jl")
include("models/discrete.jl")
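The refactored `filter` and `step` above now thread a single `state` value through predict and update, and fire the callback with a stage marker (`PostInit`, `PostPredict`, `PostUpdate`, and, for particle filters, `PostResample`). Below is a minimal sketch, not part of the diff, of what a user-defined callback could look like under this interface. It assumes `AbstractCallback` and the stage types are defined in `callbacks.jl` and reachable from the module namespace, that `PostUpdate` is an instance of `PostUpdateCallback` as the kalman.jl methods further down suggest, and the struct name is purely illustrative.

using GeneralisedFilters

# Illustrative only: record the filtered state after every update step.
struct FilteredStateTracker <: GeneralisedFilters.AbstractCallback
    history::Vector{Any}
end
FilteredStateTracker() = FilteredStateTracker(Any[])

# React only to the post-update stage; the signature mirrors the call made in `step` above.
function (cb::FilteredStateTracker)(
    model, alg, iter::Integer, state, observation, ::GeneralisedFilters.PostUpdateCallback;
    kwargs...,
)
    push!(cb.history, deepcopy(state))
    return nothing
end

# Every other stage (PostInit, PostPredict, PostResample, ...) is a no-op.
(cb::FilteredStateTracker)(args...; kwargs...) = nothing

Such a callback would then be passed as `filter(rng, model, alg, observations; callback=FilteredStateTracker())`, with the multiple-dispatch on the stage argument deciding which hook actually does work.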
59 changes: 27 additions & 32 deletions GeneralisedFilters/src/algorithms/bootstrap.jl
@@ -7,28 +7,31 @@ function step(
model::AbstractStateSpaceModel,
alg::AbstractParticleFilter,
iter::Integer,
intermediate,
state,
observation;
ref_state::Union{Nothing,AbstractVector}=nothing,
callback::Union{AbstractCallback,Nothing}=nothing,
kwargs...,
)
intermediate.proposed, intermediate.ancestors = resample(
rng, alg.resampler, intermediate.filtered
)
state = resample(rng, alg.resampler, state)
isnothing(callback) ||
callback(model, alg, iter, state, observation, PostResample; kwargs...)

intermediate.proposed = predict(
rng, model, alg, iter, intermediate.proposed; ref_state=ref_state, kwargs...
)
# TODO: this is quite inelegant and should be refactored
state = predict(rng, model, alg, iter, state; ref_state=ref_state, kwargs...)

# TODO: this is quite inelegant and should be refactored. It also might introduce bugs
# with callbacks that track the ancestry (and use PostResample)
if !isnothing(ref_state)
CUDA.@allowscalar intermediate.ancestors[1] = 1
CUDA.@allowscalar state.ancestors[1] = 1
end
isnothing(callback) ||
callback(model, alg, iter, state, observation, PostPredict; kwargs...)

intermediate.filtered, ll_increment = update(
model, alg, iter, intermediate.proposed, observation; kwargs...
)
state, ll_increment = update(model, alg, iter, state, observation; kwargs...)
isnothing(callback) ||
callback(model, alg, iter, state, observation, PostUpdate; kwargs...)

return intermediate, ll_increment
return state, ll_increment
end

struct BootstrapFilter{RS<:AbstractResampler} <: AbstractParticleFilter
@@ -46,13 +49,6 @@ function BootstrapFilter(
return BootstrapFilter{ESSResampler}(N, conditional_resampler)
end

function instantiate(
::StateSpaceModel{T}, filter::BootstrapFilter, initial; kwargs...
) where {T}
N = filter.N
return ParticleIntermediate(initial, initial, Vector{Int}(undef, N))
end

function initialise(
rng::AbstractRNG,
model::StateSpaceModel{T},
@@ -71,38 +67,37 @@ function predict(
model::StateSpaceModel,
filter::BootstrapFilter,
step::Integer,
filtered::ParticleDistribution;
state::ParticleDistribution;
ref_state::Union{Nothing,AbstractVector}=nothing,
kwargs...,
)
new_particles = map(
x -> SSMProblems.simulate(rng, model.dyn, step, x; kwargs...), collect(filtered)
state.particles = map(
x -> SSMProblems.simulate(rng, model.dyn, step, x; kwargs...), collect(state)
)
# Don't need to deep copy weights as filtered will be overwritten in the update step
proposed = ParticleDistribution(new_particles, filtered.log_weights)

return update_ref!(proposed, ref_state, step)
return update_ref!(state, ref_state, step)
end

function update(
model::StateSpaceModel{T},
filter::BootstrapFilter,
step::Integer,
proposed::ParticleDistribution,
state::ParticleDistribution,
observation;
kwargs...,
) where {T}
old_ll = logsumexp(state.log_weights)

log_increments = map(
x -> SSMProblems.logdensity(model.obs, step, x, observation; kwargs...),
collect(proposed),
collect(state),
)

new_weights = proposed.log_weights + log_increments
filtered = ParticleDistribution(deepcopy(proposed.particles), new_weights)
state.log_weights += log_increments

ll_increment = logsumexp(filtered.log_weights) - logsumexp(proposed.log_weights)
ll_increment = logsumexp(state.log_weights) - old_ll

return filtered, ll_increment
return state, ll_increment
end

# Application of bootstrap filter to hierarchical models
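The reworked bootstrap `update` above no longer builds a separate `filtered` distribution: it adds the per-particle observation log-densities to `state.log_weights` in place and reports the evidence increment as the difference of `logsumexp` values taken before and after. A standalone numerical sketch of that bookkeeping follows (particle values invented for illustration; `logsumexp` is taken from LogExpFunctions here, though the package may import it from elsewhere).

using LogExpFunctions: logsumexp

log_weights    = log.([0.2, 0.5, 0.3])   # pre-update particle log-weights (unnormalised)
log_increments = [-1.2, -0.7, -2.0]      # per-particle observation log-densities

old_ll          = logsumexp(log_weights)
new_log_weights = log_weights .+ log_increments
ll_increment    = logsumexp(new_log_weights) - old_ll   # added to the running log-evidence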
19 changes: 13 additions & 6 deletions GeneralisedFilters/src/algorithms/kalman.jl
@@ -239,7 +239,7 @@ struct KalmanSmoother <: AbstractSmoother end

const KS = KalmanSmoother()

struct StateCallback{T}
struct StateCallback{T} <: AbstractCallback
proposed_states::Vector{Gaussian{Vector{T},Matrix{T}}}
filtered_states::Vector{Gaussian{Vector{T},Matrix{T}}}
end
@@ -251,21 +251,28 @@ function StateCallback(N::Integer, T::Type)
end

function (callback::StateCallback)(
model::LinearGaussianStateSpaceModel, algo::KalmanFilter, states, obs; kwargs...
model::LinearGaussianStateSpaceModel,
algo::KalmanFilter,
iter::Integer,
state,
obs,
::PostPredictCallback;
kwargs...,
)
callback.proposed_states[iter] = deepcopy(state)
return nothing
end

function (callback::StateCallback)(
model::LinearGaussianStateSpaceModel,
algo::KalmanFilter,
iter::Integer,
states,
obs;
state,
obs,
::PostUpdateCallback;
kwargs...,
)
callback.proposed_states[iter] = states.proposed
callback.filtered_states[iter] = states.filtered
callback.filtered_states[iter] = deepcopy(state)
return nothing
end

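A sketch of how the updated `StateCallback` could be driven end to end, again not part of the diff. It assumes a `LinearGaussianStateSpaceModel` named `model` and an observation vector `ys` are already constructed, that `StateCallback` and `KalmanFilter` are reachable from the package namespace, that `KalmanFilter` has a zero-argument constructor, and that `filter` takes an RNG as its first positional argument (that line is not visible in this hunk).

using Random, GeneralisedFilters

rng = Random.default_rng()
cb  = GeneralisedFilters.StateCallback(length(ys), Float64)  # one Gaussian slot per time step

# The PostPredict / PostUpdate methods above fill proposed_states and filtered_states.
state, log_evidence = GeneralisedFilters.filter(
    rng, model, GeneralisedFilters.KalmanFilter(), ys; callback=cb
)

cb.filtered_states[end]   # final filtered Gaussian; matches the returned `state`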