Add MLE/MAP functionality #1230


Merged
41 commits merged on May 23, 2020
Changes from 36 commits

Commits
5f24ae3
Add Optim dependency
cpfiffer Apr 23, 2020
f8edbab
Export MLE/MAP
cpfiffer Apr 23, 2020
7f6a03b
Fix stupid _params_to_array behavior
cpfiffer Apr 23, 2020
d682d9b
Add MLE/MAP
cpfiffer Apr 23, 2020
3611719
Update src/modes/ModeEstimation.jl
cpfiffer Apr 23, 2020
43b1cc8
Merge branch 'master' into csp/modes
cpfiffer Apr 28, 2020
40c80db
Change optimizer.
cpfiffer Apr 28, 2020
23a36bc
Merge branch 'master' into csp/modes
cpfiffer Apr 28, 2020
2ba1c81
Match csp/hessian-bug
cpfiffer May 1, 2020
fd4a4de
Merge branch 'master' into csp/modes
cpfiffer May 8, 2020
dc0dd1c
Addressing comments, fixing bugs, adding tests
cpfiffer May 8, 2020
dca46e8
Removed extraneous model call
cpfiffer May 8, 2020
e616eb8
Add docstrings
cpfiffer May 8, 2020
f9fa51f
Update src/modes/ModeEstimation.jl
cpfiffer May 8, 2020
3fc51d9
Minor corrections.
cpfiffer May 10, 2020
3e8c64c
Merge branch 'csp/modes' of github.com:TuringLang/Turing.jl into csp/…
cpfiffer May 10, 2020
b1a4861
Add NamedArrays to extras and compat
cpfiffer May 12, 2020
4373c88
Fix dependencies
cpfiffer May 12, 2020
0a769b7
Update tests & address comments
cpfiffer May 19, 2020
e660fb5
Correct Project.toml
cpfiffer May 19, 2020
4fc1d73
Correct imports
cpfiffer May 19, 2020
cb56cd2
Renaming invlink
cpfiffer May 19, 2020
fab5b2d
Address comments
cpfiffer May 20, 2020
1f2e8cb
Remove Optim from compat
cpfiffer May 20, 2020
d97f2dc
Minor correction
cpfiffer May 20, 2020
16bbdf6
Update src/modes/ModeEstimation.jl
cpfiffer May 20, 2020
403ddcd
Update src/modes/ModeEstimation.jl
cpfiffer May 20, 2020
58fd2cf
Update src/modes/ModeEstimation.jl
cpfiffer May 20, 2020
3eb5251
Use getval
cpfiffer May 20, 2020
a8b9869
Merge branch 'csp/modes' of github.com:TuringLang/Turing.jl into csp/…
cpfiffer May 20, 2020
0809329
Update src/modes/ModeEstimation.jl
cpfiffer May 20, 2020
065bc18
Tidying, fixing tests
cpfiffer May 20, 2020
e652e00
Merge branch 'csp/modes' of github.com:TuringLang/Turing.jl into csp/…
cpfiffer May 20, 2020
1e5aea8
Replaced >= with >, because I am a fool
cpfiffer May 20, 2020
199c00c
Use function notation for .~
cpfiffer May 20, 2020
bd0c1bb
Link the model vi after extracting optimized vals
cpfiffer May 20, 2020
05ae0d4
Make sure linking status is right for Hessian
cpfiffer May 20, 2020
345776c
Update src/Turing.jl
cpfiffer May 21, 2020
a2b159c
Update src/modes/ModeEstimation.jl
cpfiffer May 21, 2020
856f282
Change NamedArrays to dependency
cpfiffer May 21, 2020
8cd87aa
Add warning if no convergence occurred, adapt to DynamicPPL master
cpfiffer May 21, 2020
6 changes: 5 additions & 1 deletion Project.toml
@@ -17,6 +17,7 @@ Libtask = "6f1fad26-d15e-5dc8-ae53-837a1d7b8c9f"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
LogDensityProblems = "6fdf6af0-433a-55f7-b3ed-c6c6e0b8df7c"
MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d"
Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"
ProgressLogging = "33c8b6b6-d38a-422a-b730-caa89a2f386c"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
@@ -41,6 +42,7 @@ ForwardDiff = "0.10.3"
Libtask = "0.4"
LogDensityProblems = "^0.9, 0.10"
MCMCChains = "3.0.7"
Optim = "0.20, 0.21"
ProgressLogging = "0.1"
Reexport = "0.2.0"
Requires = "0.5, 1.0"
@@ -56,6 +58,8 @@ CmdStan = "593b3428-ca2f-500c-ae53-031589ec8ddd"
DynamicHMC = "bbc10e6e-7c05-544b-b16e-64fede858acb"
FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000"
Memoization = "6fafb56a-5788-4b4e-91ca-c0cea6611c73"
NamedArrays = "86f7a689-2022-50b4-a561-43c23ac3c673"
Optim = "429524aa-4258-5aef-a3af-852621145aeb"
PDMats = "90014a1f-27ba-587c-ab20-58faa44d9150"
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
@@ -66,4 +70,4 @@ UnicodePlots = "b8865327-cd53-5732-bb35-84acbb429228"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[targets]
test = ["Pkg", "PDMats", "TerminalLoggers", "Test", "UnicodePlots", "StatsBase", "FiniteDifferences", "DynamicHMC", "CmdStan", "BenchmarkTools", "Zygote", "ReverseDiff", "Memoization"]
test = ["Pkg", "PDMats", "TerminalLoggers", "Test", "UnicodePlots", "StatsBase", "FiniteDifferences", "DynamicHMC", "CmdStan", "BenchmarkTools", "Zygote", "ReverseDiff", "Memoization", "Optim", "NamedArrays"]
12 changes: 10 additions & 2 deletions src/Turing.jl
@@ -58,6 +58,11 @@ using .Variational
end
end

@init @require Optim="429524aa-4258-5aef-a3af-852621145aeb" @eval begin
include("modes/ModeEstimation.jl")
export MAP, MLE, optimize
end
Comment on lines +61 to +64 (Member):
Just noticed: users will have problems if they load Optim but not NamedArrays. We should add a nested @require block that checks for NamedArrays if we want to keep it optional (which I guess we want). Otherwise one could think about returning a named tuple as default and just provide some optional way for converting it to a NamedArray.
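
A minimal sketch of the nested check suggested here (hypothetical, not part of this diff; UUIDs copied from the Project.toml hunk above):

```julia
@init @require Optim="429524aa-4258-5aef-a3af-852621145aeb" @eval begin
    @require NamedArrays="86f7a689-2022-50b4-a561-43c23ac3c673" begin
        include("modes/ModeEstimation.jl")
        export MAP, MLE, optimize
    end
end
```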


###########
# Exports #
###########
@@ -87,7 +92,7 @@ export @model, # modelling
CSMC,
PG,

vi, # variational inference
vi, # variational inference
ADVI,

sample, # inference
@@ -110,5 +115,8 @@ export @model, # modelling
LogPoisson,
NamedDist,
filldist,
arraydist
arraydist,

MLE, # mode estimation tools
MAP
Member:
Can you move these exports to the @require block above?

Member:
You forgot to remove the exports here.

end
12 changes: 7 additions & 5 deletions src/inference/Inference.jl
Original file line number Diff line number Diff line change
@@ -305,19 +305,21 @@ Return a named tuple of parameters.
getparams(t) = t.θ
getparams(t::VarInfo) = tonamedtuple(TypedVarInfo(t))

function _params_to_array(ts)
names_set = Set{String}()
function _params_to_array(ts::Vector)
names = Vector{String}()
# Extract the parameter names and values from each transition.
dicts = map(ts) do t
nms, vs = flatten_namedtuple(getparams(t))
for nm in nms
push!(names_set, nm)
if !(nm in names)
push!(names, nm)
end
end
# Convert the names and values to a single dictionary.
return Dict(nms[j] => vs[j] for j in 1:length(vs))
end
names = collect(names_set)
vals = [get(dicts[i], key, missing) for i in eachindex(dicts),
# names = collect(names_set)
vals = [get(dicts[i], key, missing) for i in eachindex(dicts),
(j, key) in enumerate(names)]

return names, vals
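Context for the `_params_to_array` change above: `Set` iteration order is unspecified in Julia, so collecting names from a `Set` could emit the parameter columns in arbitrary order; pushing into a `Vector` on first sight preserves a stable, first-seen ordering. A minimal sketch of that logic (names are illustrative):

```julia
names = String[]
for nm in ["m", "s", "m"]   # later duplicates are skipped
    nm in names || push!(names, nm)
end
@assert names == ["m", "s"]
```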
284 changes: 284 additions & 0 deletions src/modes/ModeEstimation.jl
@@ -0,0 +1,284 @@
using ..Turing
using ..Bijectors
using LinearAlgebra

import ..AbstractMCMC: AbstractSampler
import ..DynamicPPL
import ..DynamicPPL: Model, AbstractContext, VarInfo, VarName,
_getindex, getsym, getfield, settrans!, setorder!,
get_and_set_val!, istrans, tilde, dot_tilde
import Optim
import Optim: optimize
import NamedArrays
import ..ForwardDiff
import StatsBase
import Printf

struct MLE end
struct MAP end

"""
OptimizationContext{C<:AbstractContext} <: AbstractContext

The `OptimizationContext` transforms variables to their constrained space, but
does not add the log-Jacobian of that transformation to the density. This context
is intended to allow an optimizer to search freely over R^n.
"""
struct OptimizationContext{C<:AbstractContext} <: AbstractContext
context::C
end
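
For orientation (mirroring the `optimize` methods later in this file): wrapping a `LikelihoodContext` drops the priors and yields an MLE objective, while wrapping a `DefaultContext` keeps the full log joint for MAP.

```julia
mle_ctx = OptimizationContext(DynamicPPL.LikelihoodContext())  # log likelihood only -> MLE
map_ctx = OptimizationContext(DynamicPPL.DefaultContext())     # log joint -> MAP
```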

# assume
function DynamicPPL.tilde(ctx::OptimizationContext{<:LikelihoodContext}, spl, dist, vn::VarName, inds, vi)
r = vi[vn]
return r, 0
end

function DynamicPPL.tilde(ctx::OptimizationContext, spl, dist, vn::VarName, inds, vi)
r = vi[vn]
return r, Distributions.logpdf(dist, r)
end

# observe
function DynamicPPL.tilde(ctx::OptimizationContext{<:PriorContext}, sampler, right, left, vi)
return 0
end

function DynamicPPL.tilde(ctx::OptimizationContext, sampler, dist, value, vi)
return Distributions.logpdf(dist, value)
end

# dot assume
function DynamicPPL.dot_tilde(ctx::OptimizationContext{<:LikelihoodContext}, sampler, right, left, vn::VarName, _, vi)
vns, dist = get_vns_and_dist(right, left, vn)
r = getval(vi, vns)
return r, 0
end

function DynamicPPL.dot_tilde(ctx::OptimizationContext, sampler, right, left, vn::VarName, _, vi)
vns, dist = get_vns_and_dist(right, left, vn)
r = getval(vi, vns)
return r, loglikelihood(dist, r)
end

# dot observe
function DynamicPPL.dot_tilde(ctx::OptimizationContext{<:PriorContext}, sampler, right, left, vn, _, vi)
return 0
end

function DynamicPPL.dot_tilde(ctx::OptimizationContext{<:PriorContext}, sampler, right, left, vi)
return 0
end

function DynamicPPL.dot_tilde(ctx::OptimizationContext, sampler, right, left, vn, _, vi)
vns, dist = get_vns_and_dist(right, left, vn)
r = getval(vi, vns)
return loglikelihood(dist, r)
end

function DynamicPPL.dot_tilde(ctx::OptimizationContext, sampler, dists, value, vi)
return sum(Distributions.logpdf.(dists, value))
end

function getval(
vi,
vns::AbstractVector{<:VarName},
)
r = vi[vns]
return r
end

function getval(
vi,
vns::AbstractArray{<:VarName},
)
r = reshape(vi[vec(vns)], size(vns))
return r
end

"""
OptimLogDensity{M<:Model,C<:Context,V<:VarInfo}

A struct that stores the log density function of a `DynamicPPL` model.
"""
struct OptimLogDensity{M<:Model,C<:AbstractContext,V<:VarInfo}
"A `DynamicPPL.Model` constructed either with the `@model` macro or manually."
model::M
"A `DynamicPPL.AbstractContext` used to evaluate the model. `LikelihoodContext` or `DefaultContext` are typical for MAP/MLE."
context::C
"A `DynamicPPL.VarInfo` struct that will be used to update model parameters."
vi::V
end

"""
OptimLogDensity(model::Model, context::AbstractContext)

Create a callable `OptimLogDensity` struct that evaluates a model using the given `context`.
"""
function OptimLogDensity(model::Model, context::AbstractContext)
init = VarInfo(model)
DynamicPPL.link!(init, DynamicPPL.SampleFromPrior())
return OptimLogDensity(model, context, init)
end

"""
(f::OptimLogDensity)(z)

Evaluate the log joint (with `DefaultContext`) or log likelihood (with `LikelihoodContext`)
at the array `z`.
"""
function (f::OptimLogDensity)(z)
spl = DynamicPPL.SampleFromPrior()

varinfo = DynamicPPL.VarInfo(f.vi, spl, z)
f.model(varinfo, spl, f.context)
return -DynamicPPL.getlogp(varinfo)
end
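
A short usage sketch (assuming `model` is any `DynamicPPL.Model` built with `@model`): the callable returns the negative log density, which is the quantity Optim minimizes.

```julia
ctx = OptimizationContext(DynamicPPL.DefaultContext())
f = OptimLogDensity(model, ctx)
spl = DynamicPPL.SampleFromPrior()
f(f.vi[spl])   # negative log joint at the prior-initialized (linked) values
```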

"""
ModeResult{
V<:NamedArrays.NamedArray,
M<:NamedArrays.NamedArray,
O<:Optim.MultivariateOptimizationResults,
S<:NamedArrays.NamedArray
}

A wrapper struct to store various results from a MAP or MLE estimation.
"""
struct ModeResult{
V<:NamedArrays.NamedArray,
O<:Optim.MultivariateOptimizationResults,
M<:OptimLogDensity
} <: StatsBase.StatisticalModel
"A vector with the resulting point estimates."
values :: V
"The stored Optim.jl results."
optim_result :: O
"The final log likelihood or log joint, depending on whether `MAP` or `MLE` was run."
lp :: Float64
"The evaluation function used to calculate the output."
f :: M
end

#############################
# Various StatsBase methods #
#############################

function Base.show(io::IO, ::MIME"text/plain", m::ModeResult)
print(io, "ModeResult with minimized lp of ")
Printf.@printf(io, "%.2f", m.lp)
println(io)
show(io, m.values)
end

function Base.show(io::IO, m::ModeResult)
show(io, m.values.array)
end

function StatsBase.coeftable(m::ModeResult)
# Get columns for coeftable.
terms = StatsBase.coefnames(m)
estimates = m.values.array[:,1]
stderrors = StatsBase.stderror(m)
tstats = estimates ./ stderrors

StatsBase.CoefTable([estimates, stderrors, tstats], ["estimate", "stderror", "tstat"], terms)
end

function StatsBase.informationmatrix(m::ModeResult; hessian_function=ForwardDiff.hessian, kwargs...)
# Calculate Hessian and information matrix.
varnames = StatsBase.coefnames(m)
info = inv(hessian_function(m.f, m.values.array[:, 1]))
return NamedArrays.NamedArray(info, (varnames, varnames))
end

StatsBase.coef(m::ModeResult) = m.values
StatsBase.coefnames(m::ModeResult) = names(m.values)[1]
StatsBase.params(m::ModeResult) = StatsBase.coefnames(m)
StatsBase.vcov(m::ModeResult) = StatsBase.informationmatrix(m)
StatsBase.loglikelihood(m::ModeResult) = m.lp

####################
# Optim.jl methods #
####################

"""
Optim.optimize(model::Model, ::MLE, args...; kwargs...)

Compute a maximum likelihood estimate of the `model`.

# Examples

```julia-repl
@model function f(x)
m ~ Normal(0, 1)
x ~ Normal(m, 1)
end

model = f(1.5)
mle = optimize(model, MLE())

# Use a different optimizer
mle = optimize(model, MLE(), NelderMead())
```
"""
function Optim.optimize(model::Model, ::MLE, args...; kwargs...)
ctx = OptimizationContext(DynamicPPL.LikelihoodContext())
return optimize(model, OptimLogDensity(model, ctx), args...; kwargs...)
end

"""
Optim.optimize(model::Model, ::MAP, args...; kwargs...)

Compute a maximum a posteriori estimate of the `model`.

# Examples

```julia-repl
@model function f(x)
m ~ Normal(0, 1)
x ~ Normal(m, 1)
end

model = f(1.5)
map_est = optimize(model, MAP())

# Use a different optimizer
map_est = optimize(model, MAP(), NelderMead())
```
"""
function Optim.optimize(model::Model, ::MAP, args...; kwargs...)
ctx = OptimizationContext(DynamicPPL.DefaultContext())
return optimize(model, OptimLogDensity(model, ctx), args...; kwargs...)
end

"""
Optim.optimize(model::Model, f::OptimLogDensity, optimizer=Optim.LBFGS(), args...; kwargs...)

Estimate a mode, i.e., compute a MLE or MAP estimate.
"""
function Optim.optimize(model::Model, f::OptimLogDensity, optimizer=Optim.LBFGS(), args...; kwargs...)
Member:
Could be nice to allow an initial solution here as a named tuple using PriorContext.

Member Author:
Do you mean as a return value or a starting point?

Member:
I think that should be part of the OptimLogDensity constructor ideally and not be handled here. For convenience, one could then forward it from the MAP and MLE methods (or include it as a field?) when OptimLogDensity is constructed.

Member:

> Do you mean as a return value or a starting point?

I mean as a starting point. We also need to make sure the varinfo here is linked. If the same objective is optimized twice, the varinfo will not be linked here I think.

Member Author:
You're right on that -- I've linked it after the optimized values are extracted. Good catch.

# Do some initialization.
spl = DynamicPPL.SampleFromPrior()
init_vals = f.vi[spl]

# Optimize!
M = Optim.optimize(f, init_vals, optimizer, args...; kwargs...)

# Get the values at the MLE/MAP point in the constrained space:
# invlink to constrained space, extract, then restore the link.
f.vi[spl] = M.minimizer
invlink!(f.vi, spl)
vals = f.vi[spl]
link!(f.vi, spl)

# Make one transition to get the parameter names.
ts = [Turing.Inference.Transition(DynamicPPL.tonamedtuple(f.vi), DynamicPPL.getlogp(f.vi))]
varnames, _ = Turing.Inference._params_to_array(ts)

# Store the parameters and their names in an array.
vmat = NamedArrays.NamedArray(vals, varnames)

return ModeResult(vmat, M, -M.minimum, f)
end
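
For reference, a minimal end-to-end sketch of the API added in this file, following the docstring examples above (model and data are illustrative):

```julia
using Turing, Optim

@model function gdemo(x)
    s ~ InverseGamma(2, 3)
    m ~ Normal(0, sqrt(s))
    for i in eachindex(x)
        x[i] ~ Normal(m, sqrt(s))
    end
end

model = gdemo([1.5, 2.0])

mle_est = optimize(model, MLE())                  # LBFGS by default
map_est = optimize(model, MAP(), NelderMead())    # with a different optimizer
```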