Skip to content

Commit 3590800

Browse files
authored
Merge pull request #1232 from TuringLang/csp/hessian-fix
Fix hessian bug
2 parents 1c36342 + 0568fda commit 3590800

File tree

2 files changed

+39
-1
lines changed

2 files changed

+39
-1
lines changed

src/inference/Inference.jl

+1-1
Original file line numberDiff line numberDiff line change
@@ -566,7 +566,7 @@ function get_matching_type(
566566
end
567567
function get_matching_type(
568568
spl::AbstractSampler,
569-
vi,
569+
vi,
570570
::Type{<:AbstractFloat},
571571
)
572572
return floatof(eltype(vi, spl))

test/core/ad.jl

+38
Original file line numberDiff line numberDiff line change
@@ -302,4 +302,42 @@ _to_cov(B) = B * B' + Matrix(I, size(B)...)
302302
Turing.setadbackend(:zygote)
303303
sample(invwishart(), HMC(0.01, 1), 1000);
304304
end
305+
@testset "Hessian test" begin
306+
@model function tst(x, ::Type{TV}=Vector{Float64}) where {TV}
307+
params = TV(undef, 2)
308+
@. params ~ Normal(0, 1)
309+
310+
x ~ MvNormal(params, 1)
311+
end
312+
313+
function make_logjoint(model::DynamicPPL.Model, ctx::DynamicPPL.AbstractContext)
314+
# setup
315+
varinfo_init = Turing.VarInfo(model)
316+
spl = DynamicPPL.SampleFromPrior()
317+
DynamicPPL.link!(varinfo_init, spl)
318+
319+
function logπ(z; unlinked = false)
320+
varinfo = DynamicPPL.VarInfo(varinfo_init, spl, z)
321+
322+
unlinked && DynamicPPL.invlink!(varinfo_init, spl)
323+
model(varinfo, spl, ctx)
324+
unlinked && DynamicPPL.link!(varinfo_init, spl)
325+
326+
return -DynamicPPL.getlogp(varinfo)
327+
end
328+
329+
return logπ
330+
end
331+
332+
data = [0.5, -0.5]
333+
model = tst(data)
334+
335+
likelihood = make_logjoint(model, DynamicPPL.LikelihoodContext())
336+
target(x) = likelihood(x, unlinked=true)
337+
338+
H_f = ForwardDiff.hessian(target, zeros(2))
339+
H_r = ReverseDiff.hessian(target, zeros(2))
340+
@test H_f == [1.0 0.0; 0.0 1.0]
341+
@test H_f == H_r
342+
end
305343
end

0 commit comments

Comments
 (0)