|
| 1 | +using GeneralisedFilters |
| 2 | +using SSMProblems |
| 3 | +using LinearAlgebra |
| 4 | +using Random |
| 5 | + |
| 6 | +## TOY MODEL ############################################################################### |
| 7 | + |
# Two-dimensional linear-Gaussian toy model (taken from an example in Kalman.jl).
# The scalar parameter θ enters the transition matrix, so the filtering
# log-likelihood is a differentiable function of θ.
function toy_model(θ::T) where {T<:Real}
    # Initial state prior x_0 ~ N(μ0, Σ0)
    init_mean = T[1.0, 0.0]
    init_cov = Diagonal(ones(T, 2))

    # Latent transition x_{t+1} = A x_t + b + w_t, w_t ~ N(0, Q)
    trans_mat = T[
        0.8 θ/2
        -0.1 0.8
    ]
    trans_cov = Diagonal(T[0.2, 1.0])
    trans_bias = zeros(T, 2)

    # Observation y_t = H x_t + c + v_t, v_t ~ N(0, R); only the first
    # state component is observed.
    obs_mat = Matrix{T}(I, 1, 2)
    obs_cov = Diagonal(T[0.2])
    obs_bias = zeros(T, 1)

    return create_homogeneous_linear_gaussian_model(
        init_mean, init_cov, trans_mat, trans_bias, trans_cov, obs_mat, obs_bias, obs_cov
    )
end
| 23 | + |
# Data generation: simulate a short (20-step) observation sequence from the
# model at the "true" parameter θ = 1.0, with a fixed seed for reproducibility.
# NOTE: the same `rng` is reused by the AD tests below, so results there depend
# on the RNG state left by this draw.
rng = MersenneTwister(1234)
true_model = toy_model(1.0)
_, _, ys = sample(rng, true_model, 20)
| 28 | + |
## RUN MOONCAKE TESTS ######################################################################
| 30 | + |
| 31 | +using DifferentiationInterface |
| 32 | +import Mooncake |
| 33 | +using DistributionsAD |
| 34 | + |
# Negative log-likelihood of `data` under `toy_model(θ[])`, computed with the
# given filtering algorithm `algo`. `θ` is a one-element container (e.g. `[0.7]`)
# so AD backends can treat the parameter as an array-valued input.
function build_objective(rng, θ, algo, data)
    model = toy_model(θ[])
    _, loglik = GeneralisedFilters.filter(rng, model, algo, data)
    return -loglik
end
| 39 | + |
# Kalman filter likelihood testing (works, but is slow): the KF log-likelihood
# is a smooth function of θ, so Mooncake's rule test should pass.
logℓ1 = θ -> build_objective(rng, θ, KF(), ys)
Mooncake.TestUtils.test_rule(rng, logℓ1, [0.7]; is_primitive=false, debug_mode=true)

# Bootstrap filter likelihood testing (shouldn't work) — presumably because the
# particle-filter estimator involves resampling; NOTE(review): expected to fail.
logℓ2 = θ -> build_objective(rng, θ, BF(512), ys)
Mooncake.TestUtils.test_rule(rng, logℓ2, [0.7]; is_primitive=false, debug_mode=true)
| 47 | + |
| 48 | +## FOR USE WITH DIFFERENTIATION INTERFACE ################################################## |
| 49 | + |
# The data is part of the objective but is held constant by
# DifferentiationInterface (passed below via `Constant`).
logℓ3 = (θ, data) -> build_objective(rng, θ, KF(), data)

# Default-configured Mooncake backend.
backend = AutoMooncake(; config=nothing)

# Prepare once so repeated gradient evaluations are faster.
grad_prep = prepare_gradient(logℓ3, backend, [0.7], Constant(ys))

# Proof of concept: evaluate the gradient at several parameter values.
for θ0 in ([0.7], [0.8], [0.9], [1.0])
    DifferentiationInterface.gradient(logℓ3, grad_prep, backend, θ0, Constant(ys))
end
0 commit comments