diff --git a/src/inference/Inference.jl b/src/inference/Inference.jl index 9510d9685..b3db2982d 100644 --- a/src/inference/Inference.jl +++ b/src/inference/Inference.jl @@ -566,7 +566,7 @@ function get_matching_type( end function get_matching_type( spl::AbstractSampler, - vi, + vi, ::Type{<:AbstractFloat}, ) return floatof(eltype(vi, spl)) diff --git a/test/core/ad.jl b/test/core/ad.jl index 77cb06e0c..c2a73cee8 100644 --- a/test/core/ad.jl +++ b/test/core/ad.jl @@ -302,4 +302,42 @@ _to_cov(B) = B * B' + Matrix(I, size(B)...) Turing.setadbackend(:zygote) sample(invwishart(), HMC(0.01, 1), 1000); end + @testset "Hessian test" begin + @model function tst(x, ::Type{TV}=Vector{Float64}) where {TV} + params = TV(undef, 2) + @. params ~ Normal(0, 1) + + x ~ MvNormal(params, 1) + end + + function make_logjoint(model::DynamicPPL.Model, ctx::DynamicPPL.AbstractContext) + # setup + varinfo_init = Turing.VarInfo(model) + spl = DynamicPPL.SampleFromPrior() + DynamicPPL.link!(varinfo_init, spl) + + function logπ(z; unlinked = false) + varinfo = DynamicPPL.VarInfo(varinfo_init, spl, z) + + unlinked && DynamicPPL.invlink!(varinfo_init, spl) + model(varinfo, spl, ctx) + unlinked && DynamicPPL.link!(varinfo_init, spl) + + return -DynamicPPL.getlogp(varinfo) + end + + return logπ + end + + data = [0.5, -0.5] + model = tst(data) + + likelihood = make_logjoint(model, DynamicPPL.LikelihoodContext()) + target(x) = likelihood(x, unlinked=true) + + H_f = ForwardDiff.hessian(target, zeros(2)) + H_r = ReverseDiff.hessian(target, zeros(2)) + @test H_f == [1.0 0.0; 0.0 1.0] + @test H_f == H_r + end end