-
Notifications
You must be signed in to change notification settings - Fork 226
Add MLE/MAP functionality #1230
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 36 commits
5f24ae3
f8edbab
7f6a03b
d682d9b
3611719
43b1cc8
40c80db
23a36bc
2ba1c81
fd4a4de
dc0dd1c
dca46e8
e616eb8
f9fa51f
3fc51d9
3e8c64c
b1a4861
4373c88
0a769b7
e660fb5
4fc1d73
cb56cd2
fab5b2d
1f2e8cb
d97f2dc
16bbdf6
403ddcd
58fd2cf
3eb5251
a8b9869
0809329
065bc18
e652e00
1e5aea8
199c00c
bd0c1bb
05ae0d4
345776c
a2b159c
856f282
8cd87aa
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -58,6 +58,11 @@ using .Variational | |
end | ||
end | ||
|
||
# Conditionally load the mode-estimation tools only when the user has loaded
# Optim.jl (identified by its package UUID); this keeps Optim an optional dependency.
@init @require Optim="429524aa-4258-5aef-a3af-852621145aeb" @eval begin
    include("modes/ModeEstimation.jl")
    export MAP, MLE, optimize
end
|
||
########### | ||
# Exports # | ||
########### | ||
|
@@ -87,7 +92,7 @@ export @model, # modelling | |
CSMC, | ||
PG, | ||
|
||
vi, # variational inference | ||
vi, # variational inference | ||
ADVI, | ||
|
||
sample, # inference | ||
|
@@ -110,5 +115,8 @@ export @model, # modelling | |
LogPoisson, | ||
NamedDist, | ||
filldist, | ||
arraydist | ||
arraydist, | ||
|
||
MLE, # mode estimation tools | ||
MAP | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Can you move these exports to the There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. You forgot to remove the exports here.
cpfiffer marked this conversation as resolved.
Show resolved
Hide resolved
|
||
end |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,284 @@ | ||
using ..Turing | ||
using ..Bijectors | ||
using LinearAlgebra | ||
|
||
import ..AbstractMCMC: AbstractSampler | ||
import ..DynamicPPL | ||
import ..DynamicPPL: Model, AbstractContext, VarInfo, AbstractContext, VarName, | ||
_getindex, getsym, getfield, settrans!, setorder!, | ||
get_and_set_val!, istrans, tilde, dot_tilde | ||
import Optim | ||
import Optim: optimize | ||
import NamedArrays | ||
cpfiffer marked this conversation as resolved.
Show resolved
Hide resolved
|
||
import ..ForwardDiff | ||
import StatsBase | ||
import Printf | ||
|
||
# Singleton flag types used for dispatch in `Optim.optimize`: pass `MLE()` for
# maximum likelihood estimation or `MAP()` for maximum a posteriori estimation.
struct MLE end
struct MAP end
|
||
"""
    OptimizationContext{C<:AbstractContext} <: AbstractContext

The `OptimizationContext` transforms variables to their constrained space, but
does not use the density with respect to the transformation. This context is
intended to allow an optimizer to sample in R^n freely.
"""
struct OptimizationContext{C<:AbstractContext} <: AbstractContext
    # The wrapped context; `LikelihoodContext` selects MLE behavior,
    # `DefaultContext` selects MAP behavior in the `tilde`/`dot_tilde` overloads below.
    context::C
end
|
||
# assume: under a `LikelihoodContext` the prior term is dropped from the objective,
# so an assumed variable contributes zero log density.
function DynamicPPL.tilde(ctx::OptimizationContext{<:LikelihoodContext}, spl, dist, vn::VarName, inds, vi)
    return vi[vn], 0
end
|
||
# assume (generic): the variable's prior log density enters the objective.
function DynamicPPL.tilde(ctx::OptimizationContext, spl, dist, vn::VarName, inds, vi)
    val = vi[vn]
    return val, Distributions.logpdf(dist, val)
end
|
||
# observe: under a `PriorContext` observations are ignored, contributing zero.
DynamicPPL.tilde(ctx::OptimizationContext{<:PriorContext}, sampler, right, left, vi) = 0
|
||
# observe (generic): an observation adds its log likelihood to the objective.
DynamicPPL.tilde(ctx::OptimizationContext, sampler, dist, value, vi) =
    Distributions.logpdf(dist, value)
|
||
# dot assume: under a `LikelihoodContext` the prior of a broadcasted assumption
# contributes zero log density.
function DynamicPPL.dot_tilde(ctx::OptimizationContext{<:LikelihoodContext}, sampler, right, left, vn::VarName, _, vi)
    vns, dist = get_vns_and_dist(right, left, vn)
    return getval(vi, vns), 0
end
|
||
# dot assume (generic): a broadcasted assumption contributes its joint log likelihood
# under the associated distribution.
function DynamicPPL.dot_tilde(ctx::OptimizationContext, sampler, right, left, vn::VarName, _, vi)
    vns, dist = get_vns_and_dist(right, left, vn)
    vals = getval(vi, vns)
    return vals, loglikelihood(dist, vals)
end
|
||
# dot observe: under a `PriorContext` broadcasted observations are ignored.
DynamicPPL.dot_tilde(ctx::OptimizationContext{<:PriorContext}, sampler, right, left, vn, _, vi) = 0
|
||
# dot observe (no variable name): likewise ignored under a `PriorContext`.
DynamicPPL.dot_tilde(ctx::OptimizationContext{<:PriorContext}, sampler, right, left, vi) = 0
|
||
# dot observe (generic): a broadcasted observation adds its joint log likelihood.
function DynamicPPL.dot_tilde(ctx::OptimizationContext, sampler, right, left, vn, _, vi)
    vns, dist = get_vns_and_dist(right, left, vn)
    return loglikelihood(dist, getval(vi, vns))
end
|
||
# dot observe against a collection of distributions: sum the element-wise log densities.
DynamicPPL.dot_tilde(ctx::OptimizationContext, sampler, dists, value, vi) =
    sum(Distributions.logpdf.(dists, value))
|
||
# Fetch the values of a vector of variable names directly from the VarInfo.
getval(vi, vns::AbstractVector{<:VarName}) = vi[vns]
|
||
# Fetch values for a multidimensional array of variable names: index with the
# flattened names, then restore the original array shape.
function getval(vi, vns::AbstractArray{<:VarName})
    flat = vi[vec(vns)]
    return reshape(flat, size(vns))
end
|
||
"""
    OptimLogDensity{M<:Model,C<:AbstractContext,V<:VarInfo}

A struct that stores the log density function of a `DynamicPPL` model.
"""
struct OptimLogDensity{M<:Model,C<:AbstractContext,V<:VarInfo}
    "A `DynamicPPL.Model` constructed either with the `@model` macro or manually."
    model::M
    "A `DynamicPPL.AbstractContext` used to evaluate the model. `LikelihoodContext` or `DefaultContext` are typical for MAP/MLE."
    context::C
    "A `DynamicPPL.VarInfo` struct that will be used to update model parameters."
    vi::V
end
|
||
"""
    OptimLogDensity(model::Model, context::AbstractContext)

Create a callable `OptimLogDensity` struct that evaluates a model using the given `context`.
"""
function OptimLogDensity(model::Model, context::AbstractContext)
    varinfo = VarInfo(model)
    # Link the VarInfo so the optimizer works in the unconstrained space R^n.
    DynamicPPL.link!(varinfo, DynamicPPL.SampleFromPrior())
    return OptimLogDensity(model, context, varinfo)
end
|
||
"""
    (f::OptimLogDensity)(z)

Evaluate the log joint (with `DefaultContext`) or log likelihood (with `LikelihoodContext`)
at the array `z`.
"""
function (f::OptimLogDensity)(z)
    sampler = DynamicPPL.SampleFromPrior()
    # Build a fresh VarInfo holding the candidate point `z`, then run the model
    # under `f.context` to accumulate the log density.
    varinfo = DynamicPPL.VarInfo(f.vi, sampler, z)
    f.model(varinfo, sampler, f.context)
    # Negated because Optim.jl minimizes, while we want to maximize the log density.
    return -DynamicPPL.getlogp(varinfo)
end
|
||
"""
    ModeResult{
        V<:NamedArrays.NamedArray,
        O<:Optim.MultivariateOptimizationResults,
        M<:OptimLogDensity
    }

A wrapper struct to store various results from a MAP or MLE estimation.
"""
struct ModeResult{
    V<:NamedArrays.NamedArray,
    O<:Optim.MultivariateOptimizationResults,
    M<:OptimLogDensity
} <: StatsBase.StatisticalModel
    "A vector with the resulting point estimates."
    values :: V
    "The stored Optim.jl results."
    optim_result :: O
    "The final log likelihood or log joint, depending on whether `MAP` or `MLE` was run."
    lp :: Float64
    "The evaluation function used to calculate the output."
    f :: M
end
#############################
# Various StatsBase methods #
#############################

# Multi-line REPL display: a one-line summary of the optimized log density
# followed by the named parameter estimates.
function Base.show(io::IO, ::MIME"text/plain", m::ModeResult)
    Printf.@printf(io, "ModeResult with minimized lp of %.2f\n", m.lp)
    show(io, m.values)
end
|
||
# Compact display: just the raw estimate array.
Base.show(io::IO, m::ModeResult) = show(io, m.values.array)
|
||
# Build a StatsBase coefficient table with point estimates, standard errors,
# and the implied t-statistics.
function StatsBase.coeftable(m::ModeResult)
    names_col = StatsBase.coefnames(m)
    ests = m.values.array[:, 1]
    ses = StatsBase.stderror(m)
    tstats = ests ./ ses

    return StatsBase.CoefTable([ests, ses, tstats], ["estimate", "stderror", "tstat"], names_col)
end
|
||
# Observed information: the inverse Hessian of the objective at the mode,
# wrapped in a NamedArray labelled by parameter names on both axes.
# `hessian_function` is pluggable; ForwardDiff is the default.
function StatsBase.informationmatrix(m::ModeResult; hessian_function=ForwardDiff.hessian, kwargs...)
    varnames = StatsBase.coefnames(m)
    hess = hessian_function(m.f, m.values.array[:, 1])
    return NamedArrays.NamedArray(inv(hess), (varnames, varnames))
end
|
||
# Remaining StatsBase interface methods, all thin accessors on ModeResult.
StatsBase.coef(m::ModeResult) = m.values
StatsBase.coefnames(m::ModeResult) = names(m.values)[1]
StatsBase.params(m::ModeResult) = StatsBase.coefnames(m)
StatsBase.vcov(m::ModeResult) = StatsBase.informationmatrix(m)
StatsBase.loglikelihood(m::ModeResult) = m.lp
|
||
####################
# Optim.jl methods #
####################

"""
    Optim.optimize(model::Model, ::MLE, args...; kwargs...)

Compute a maximum likelihood estimate of the `model`.

# Examples

```julia-repl
@model function f(x)
    m ~ Normal(0, 1)
    x ~ Normal(m, 1)
end

model = f(1.5)
mle = optimize(model, MLE())

# Use a different optimizer
mle = optimize(model, MLE(), NelderMead())
```
"""
function Optim.optimize(model::Model, ::MLE, args...; kwargs...)
    # A LikelihoodContext drops the prior, so the objective is the log likelihood.
    context = OptimizationContext(DynamicPPL.LikelihoodContext())
    return optimize(model, OptimLogDensity(model, context), args...; kwargs...)
end
|
||
"""
    Optim.optimize(model::Model, ::MAP, args...; kwargs...)

Compute a maximum a posteriori estimate of the `model`.

# Examples

```julia-repl
@model function f(x)
    m ~ Normal(0, 1)
    x ~ Normal(m, 1)
end

model = f(1.5)
map_est = optimize(model, MAP())

# Use a different optimizer
map_est = optimize(model, MAP(), NelderMead())
```
"""
function Optim.optimize(model::Model, ::MAP, args...; kwargs...)
    # A DefaultContext keeps the prior, so the objective is the log joint.
    ctx = OptimizationContext(DynamicPPL.DefaultContext())
    return optimize(model, OptimLogDensity(model, ctx), args...; kwargs...)
end
|
||
"""
    Optim.optimize(model::Model, f::OptimLogDensity, optimizer=Optim.LBFGS(), args...; kwargs...)

Estimate a mode, i.e., compute a MLE or MAP estimate.
"""
function Optim.optimize(model::Model, f::OptimLogDensity, optimizer=Optim.LBFGS(), args...; kwargs...)
    # Do some initialization: optimize from the objective's current (linked) state.
    spl = DynamicPPL.SampleFromPrior()
    init_vals = f.vi[spl]

    # Optimize!
    M = Optim.optimize(f, init_vals, optimizer, args...; kwargs...)

    # Get the VarInfo at the MLE/MAP point. Invlink to read the estimates in
    # the constrained space, then re-link so `f` remains optimizable again
    # from the same state. (Qualified as `DynamicPPL.`-functions for
    # consistency with the constructor above.)
    f.vi[spl] = M.minimizer
    DynamicPPL.invlink!(f.vi, spl)
    vals = f.vi[spl]
    DynamicPPL.link!(f.vi, spl)

    # Make one transition to get the parameter names.
    ts = [Turing.Inference.Transition(DynamicPPL.tonamedtuple(f.vi), DynamicPPL.getlogp(f.vi))]
    varnames, _ = Turing.Inference._params_to_array(ts)

    # Store the parameters and their names in an array.
    vmat = NamedArrays.NamedArray(vals, varnames)

    # Negate the minimum to report the maximized log density.
    return ModeResult(vmat, M, -M.minimum, f)
end
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Just noticed: users will have problems if they load Optim but not NamedArrays. We should add a nested
@require
block that checks for NamedArrays if we want to keep it optional (which I guess we want). Otherwise one could think about returning a named tuple as default and just provide some optional way for converting it to a NamedArray.