
Commit f2f6665

Provide AD gradient for MLE/MAP (#1369)
* Use in-place gradients
* Fixing tests, tidying things up
* Increment patch version
* Change version to 0.14.0, address comments
* Remove hack to fix 2nd order optimizers
* Remove redundant FG function
* Add contexts to gradient_logp for Zygote and ReverseDiff
* One day I'll fix all the files at once
1 parent 43e2f20 commit f2f6665
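
In practice, this change lets gradient-based optimizers use the AD gradient of the log density for MLE/MAP mode estimation. A minimal usage sketch, assuming Turing 0.14 with Optim.jl loaded; the model and data below are illustrative only and are not part of this commit:

using Turing, Optim

# Illustrative model (hypothetical, not from this commit).
@model demo(x) = begin
    σ ~ InverseGamma(2, 3)
    μ ~ Normal(0, sqrt(σ))
    for i in eachindex(x)
        x[i] ~ Normal(μ, sqrt(σ))
    end
end

model = demo([1.5, 2.0, 0.5])

# With the new fgh! interface, LBFGS receives the AD gradient of the
# negative log density instead of a finite-difference approximation.
mle_estimate = optimize(model, MLE(), LBFGS())
map_estimate = optimize(model, MAP(), LBFGS())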

File tree: 6 files changed, +56 -16 lines changed


Project.toml

Lines changed: 1 addition & 1 deletion
@@ -1,6 +1,6 @@
 name = "Turing"
 uuid = "fce5fe82-541a-59a6-adf8-730c64b5f9a0"
-version = "0.13.0"
+version = "0.14.0"
 
 [deps]
 AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"

src/core/ad.jl

Lines changed: 8 additions & 4 deletions
@@ -60,6 +60,7 @@ ADBackend(::Val) = error("The requested AD backend is not available. Make sure t
 Find the autodifferentiation backend of the algorithm `alg`.
 """
 getADbackend(spl::Sampler) = getADbackend(spl.alg)
+getADbackend(spl::SampleFromPrior) = ADBackend()()
 
 """
     gradient_logp(
@@ -77,9 +78,10 @@ function gradient_logp(
     θ::AbstractVector{<:Real},
     vi::VarInfo,
     model::Model,
-    sampler::Sampler
+    sampler::AbstractSampler,
+    ctx::DynamicPPL.AbstractContext = DynamicPPL.DefaultContext()
 )
-    return gradient_logp(getADbackend(sampler), θ, vi, model, sampler)
+    return gradient_logp(getADbackend(sampler), θ, vi, model, sampler, ctx)
 end
 
 """
@@ -100,12 +102,13 @@ function gradient_logp(
     vi::VarInfo,
     model::Model,
     sampler::AbstractSampler=SampleFromPrior(),
+    ctx::DynamicPPL.AbstractContext = DynamicPPL.DefaultContext()
 )
     # Define function to compute log joint.
     logp_old = getlogp(vi)
     function f(θ)
         new_vi = VarInfo(vi, sampler, θ)
-        model(new_vi, sampler)
+        model(new_vi, sampler, ctx)
         logp = getlogp(new_vi)
         setlogp!(vi, ForwardDiff.value(logp))
         return logp
@@ -127,13 +130,14 @@ function gradient_logp(
     vi::VarInfo,
     model::Model,
     sampler::AbstractSampler = SampleFromPrior(),
+    ctx::DynamicPPL.AbstractContext = DynamicPPL.DefaultContext()
 )
     T = typeof(getlogp(vi))
 
     # Specify objective function.
     function f(θ)
         new_vi = VarInfo(vi, sampler, θ)
-        model(new_vi, sampler)
+        model(new_vi, sampler, ctx)
         return getlogp(new_vi)
     end
 
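
The new `ctx` argument threads a DynamicPPL evaluation context through to the model call, so the same AD machinery can differentiate the log joint, the likelihood only, or the prior only. A hedged sketch of a direct call, assuming the function is reachable as `Turing.Core.gradient_logp` (it is internal, not exported) and reusing the illustrative `demo` model from the sketch near the top:

using Turing, DynamicPPL

spl = DynamicPPL.SampleFromPrior()
vi  = DynamicPPL.VarInfo(model)   # `model` is the illustrative demo model above
θ   = vi[spl]

# Default context: gradient of the log joint (what HMC/NUTS need).
logp_joint, grad_joint = Turing.Core.gradient_logp(θ, vi, model, spl)

# Likelihood context: gradient of the log likelihood alone (what MLE needs).
logp_lik, grad_lik = Turing.Core.gradient_logp(
    θ, vi, model, spl, DynamicPPL.LikelihoodContext()
)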

src/core/compat/reversediff.jl

Lines changed: 2 additions & 0 deletions
@@ -20,6 +20,7 @@ function gradient_logp(
     vi::VarInfo,
     model::Model,
     sampler::AbstractSampler = SampleFromPrior(),
+    context::DynamicPPL.AbstractContext = DynamicPPL.DefaultContext()
 )
     T = typeof(getlogp(vi))
 
@@ -57,6 +58,7 @@ end
     vi::VarInfo,
     model::Model,
     sampler::AbstractSampler = SampleFromPrior(),
+    context::DynamicPPL.AbstractContext = DynamicPPL.DefaultContext()
 )
     T = typeof(getlogp(vi))
 

src/core/compat/zygote.jl

Lines changed: 1 addition & 0 deletions
@@ -10,6 +10,7 @@ function gradient_logp(
     vi::VarInfo,
     model::Model,
     sampler::AbstractSampler = SampleFromPrior(),
+    context::DynamicPPL.AbstractContext = DynamicPPL.DefaultContext()
 )
     T = typeof(getlogp(vi))
 
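
For reference, which of these `gradient_logp` overloads runs is decided by the AD backend switch. A brief sketch, assuming the corresponding AD packages are loaded so the compat files above are included via Requires:

using Turing
using ReverseDiff, Zygote           # needed for the compat overloads to load

Turing.setadbackend(:forwarddiff)   # default path in src/core/ad.jl
Turing.setadbackend(:reversediff)   # dispatches to the ReverseDiff overload
Turing.setadbackend(:zygote)        # dispatches to the Zygote overload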

src/modes/ModeEstimation.jl

Lines changed: 39 additions & 2 deletions
@@ -147,6 +147,44 @@ function (f::OptimLogDensity)(z)
     return -DynamicPPL.getlogp(varinfo)
 end
 
+function (f::OptimLogDensity)(F, G, H, z)
+    # Throw an error if a second order method was used.
+    if H !== nothing
+        error("Second order optimization is not yet supported.")
+    end
+
+    spl = DynamicPPL.SampleFromPrior()
+
+    if G !== nothing
+        # Calculate log joint and the gradient
+        l, g = gradient_logp(
+            z,
+            DynamicPPL.VarInfo(f.vi, spl, z),
+            f.model,
+            spl,
+            f.context
+        )
+
+        # Use the negative gradient because we are minimizing.
+        G[:] = -g
+
+        # If F is something, return that since we already have the
+        # log joint.
+        if F !== nothing
+            F = -l
+            return F
+        end
+    end
+
+    # No gradient necessary, just return the log joint.
+    if F !== nothing
+        F = f(z)
+        return F
+    end
+
+    return nothing
+end
+
 """
     ModeResult{
         V<:NamedArrays.NamedArray,
@@ -378,9 +416,8 @@ function _optimize(
     link!(f.vi, spl)
     init_vals = f.vi[spl]
 
-
     # Optimize!
-    M = Optim.optimize(f, init_vals, optimizer, options, args...; kwargs...)
+    M = Optim.optimize(Optim.only_fgh!(f), init_vals, optimizer, options, args...; kwargs...)
 
     # Warn the user if the optimization did not converge.
     if !Optim.converged(M)
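
The `(f::OptimLogDensity)(F, G, H, z)` method above follows Optim.jl's fgh! calling convention: Optim passes preallocated `G` (and possibly `H`) buffers plus a flag-like `F`, and the function fills whatever is non-`nothing`. A self-contained toy sketch of that convention on a one-dimensional quadratic, not tied to Turing:

using Optim

# Minimize f(x) = (x - 3)^2 using the same F/G/H convention.
function fgh!(F, G, H, x)
    if H !== nothing
        H[1, 1] = 2.0                # second derivative
    end
    if G !== nothing
        G[1] = 2.0 * (x[1] - 3.0)    # gradient, written in place
    end
    if F !== nothing
        return (x[1] - 3.0)^2        # objective value
    end
    return nothing
end

result = Optim.optimize(Optim.only_fgh!(fgh!), [0.0], LBFGS())
Optim.minimizer(result)              # ≈ [3.0]

Because the Turing method errors when `H` is requested, only first-order optimizers are supported here, which is why the test file below swaps `Newton()` for `LBFGS()`.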

test/modes/ModeEstimation.jl

Lines changed: 5 additions & 9 deletions
@@ -13,19 +13,17 @@ include(dir*"/test/test_utils/AllUtils.jl")
 @testset "ModeEstimation.jl" begin
     @testset "MLE" begin
         Random.seed!(222)
-        true_value = [0.0625031, 1.75]
+        true_value = [0.0625, 1.75]
 
         m1 = optimize(gdemo_default, MLE())
         m2 = optimize(gdemo_default, MLE(), NelderMead())
-        m3 = optimize(gdemo_default, MLE(), Newton())
-        m4 = optimize(gdemo_default, MLE(), true_value, Newton())
-        m5 = optimize(gdemo_default, MLE(), true_value)
+        m3 = optimize(gdemo_default, MLE(), true_value, LBFGS())
+        m4 = optimize(gdemo_default, MLE(), true_value)
 
         @test all(isapprox.(m1.values.array - true_value, 0.0, atol=0.01))
         @test all(isapprox.(m2.values.array - true_value, 0.0, atol=0.01))
         @test all(isapprox.(m3.values.array - true_value, 0.0, atol=0.01))
         @test all(isapprox.(m4.values.array - true_value, 0.0, atol=0.01))
-        @test all(isapprox.(m5.values.array - true_value, 0.0, atol=0.01))
     end
 
     @testset "MAP" begin
@@ -34,15 +32,13 @@ include(dir*"/test/test_utils/AllUtils.jl")
 
         m1 = optimize(gdemo_default, MAP())
         m2 = optimize(gdemo_default, MAP(), NelderMead())
-        m3 = optimize(gdemo_default, MAP(), Newton())
-        m4 = optimize(gdemo_default, MAP(), true_value, Newton())
-        m5 = optimize(gdemo_default, MAP(), true_value)
+        m3 = optimize(gdemo_default, MAP(), true_value, LBFGS())
+        m4 = optimize(gdemo_default, MAP(), true_value)
 
         @test all(isapprox.(m1.values.array - true_value, 0.0, atol=0.01))
         @test all(isapprox.(m2.values.array - true_value, 0.0, atol=0.01))
         @test all(isapprox.(m3.values.array - true_value, 0.0, atol=0.01))
         @test all(isapprox.(m4.values.array - true_value, 0.0, atol=0.01))
-        @test all(isapprox.(m5.values.array - true_value, 0.0, atol=0.01))
     end
 
     @testset "StatsBase integration" begin
