Add MLE/MAP functionality #1230


Merged: 41 commits, May 23, 2020

Commits
5f24ae3
Add Optim dependency
cpfiffer Apr 23, 2020
f8edbab
Export MLE/MAP
cpfiffer Apr 23, 2020
7f6a03b
Fix stupid _params_to_array behavior
cpfiffer Apr 23, 2020
d682d9b
Add MLE/MAP
cpfiffer Apr 23, 2020
3611719
Update src/modes/ModeEstimation.jl
cpfiffer Apr 23, 2020
43b1cc8
Merge branch 'master' into csp/modes
cpfiffer Apr 28, 2020
40c80db
Change optimizer.
cpfiffer Apr 28, 2020
23a36bc
Merge branch 'master' into csp/modes
cpfiffer Apr 28, 2020
2ba1c81
Match csp/hessian-bug
cpfiffer May 1, 2020
fd4a4de
Merge branch 'master' into csp/modes
cpfiffer May 8, 2020
dc0dd1c
Addressing comments, fixing bugs, adding tests
cpfiffer May 8, 2020
dca46e8
Removed extraneous model call
cpfiffer May 8, 2020
e616eb8
Add docstrings
cpfiffer May 8, 2020
f9fa51f
Update src/modes/ModeEstimation.jl
cpfiffer May 8, 2020
3fc51d9
Minor corrections.
cpfiffer May 10, 2020
3e8c64c
Merge branch 'csp/modes' of github.com:TuringLang/Turing.jl into csp/…
cpfiffer May 10, 2020
b1a4861
Add NamedArrays to extras and compat
cpfiffer May 12, 2020
4373c88
Fix dependencies
cpfiffer May 12, 2020
0a769b7
Update tests & address comments
cpfiffer May 19, 2020
e660fb5
Correct Project.toml
cpfiffer May 19, 2020
4fc1d73
Correct imports
cpfiffer May 19, 2020
cb56cd2
Renaming invlink
cpfiffer May 19, 2020
fab5b2d
Address comments
cpfiffer May 20, 2020
1f2e8cb
Remove Optim from compat
cpfiffer May 20, 2020
d97f2dc
Minor correction
cpfiffer May 20, 2020
16bbdf6
Update src/modes/ModeEstimation.jl
cpfiffer May 20, 2020
403ddcd
Update src/modes/ModeEstimation.jl
cpfiffer May 20, 2020
58fd2cf
Update src/modes/ModeEstimation.jl
cpfiffer May 20, 2020
3eb5251
Use getval
cpfiffer May 20, 2020
a8b9869
Merge branch 'csp/modes' of github.com:TuringLang/Turing.jl into csp/…
cpfiffer May 20, 2020
0809329
Update src/modes/ModeEstimation.jl
cpfiffer May 20, 2020
065bc18
Tidying, fixing tests
cpfiffer May 20, 2020
e652e00
Merge branch 'csp/modes' of github.com:TuringLang/Turing.jl into csp/…
cpfiffer May 20, 2020
1e5aea8
Replaced >= with >, because I am a fool
cpfiffer May 20, 2020
199c00c
Use function notation for .~
cpfiffer May 20, 2020
bd0c1bb
Link the model vi after extracting optimized vals
cpfiffer May 20, 2020
05ae0d4
Make sure linking status is right for Hessian
cpfiffer May 20, 2020
345776c
Update src/Turing.jl
cpfiffer May 21, 2020
a2b159c
Update src/modes/ModeEstimation.jl
cpfiffer May 21, 2020
856f282
Change NamedArrays to dependency
cpfiffer May 21, 2020
8cd87aa
Add warning if no convergence occurred, adapt to DynamicPPL master
cpfiffer May 21, 2020
2 changes: 2 additions & 0 deletions Project.toml
@@ -17,6 +17,7 @@ Libtask = "6f1fad26-d15e-5dc8-ae53-837a1d7b8c9f"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
LogDensityProblems = "6fdf6af0-433a-55f7-b3ed-c6c6e0b8df7c"
MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d"
Optim = "429524aa-4258-5aef-a3af-852621145aeb"
ProgressLogging = "33c8b6b6-d38a-422a-b730-caa89a2f386c"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
@@ -41,6 +42,7 @@ ForwardDiff = "0.10.3"
Libtask = "0.3.1"
LogDensityProblems = "^0.9, 0.10"
MCMCChains = "3.0.7"
Optim = "0.20"
ProgressLogging = "0.1"
Reexport = "0.2.0"
Requires = "0.5, 1.0"
9 changes: 7 additions & 2 deletions src/Turing.jl
@@ -36,6 +36,8 @@ include("inference/Inference.jl") # inference algorithms
using .Inference
include("variational/VariationalInference.jl")
using .Variational
include("modes/ModeEstimation.jl")
using .ModeEstimation

# TODO: re-design `sample` interface in MCMCChains, which unify CmdStan and Turing.
# Related: https://github.com/TuringLang/Turing.jl/issues/746
@@ -85,7 +87,7 @@ export @model, # modelling
CSMC,
PG,

-vi, # variational inference
+vi,    # variational inference
ADVI,

sample, # inference
@@ -108,5 +110,8 @@ export @model, # modelling
LogPoisson,
NamedDist,
filldist,
-arraydist
+arraydist,
+
+MLE, # mode estimation tools
+MAP
Review comment (Member):
Can you move these exports to the @require block above?

Review comment (Member):
You forgot to remove the exports here.

end
12 changes: 7 additions & 5 deletions src/inference/Inference.jl
@@ -265,18 +265,20 @@ end
# Chain making utilities #
##########################

-function _params_to_array(ts::Vector, spl::Sampler)
-    names_set = Set{String}()
+function _params_to_array(ts::Vector)
+    names = Vector{String}()
# Extract the parameter names and values from each transition.
dicts = map(ts) do t
nms, vs = flatten_namedtuple(t.θ)
for nm in nms
-            push!(names_set, nm)
+            if !(nm in names)
+                push!(names, nm)
+            end
end
# Convert the names and values to a single dictionary.
return Dict(nms[j] => vs[j] for j in 1:length(vs))
end
-    names = collect(names_set)
+    # names = collect(names_set)
vals = [get(dicts[i], key, missing) for i in eachindex(dicts),
(j, key) in enumerate(names)]
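For context (not part of the diff): replacing the `Set` with an ordered `Vector` makes the collected parameter names deterministic, since `Set` iteration order is unspecified. A minimal sketch of the new collection logic:

```julia
# Order-preserving, duplicate-free name collection, mirroring the new code.
names = Vector{String}()
for nm in ("m", "s", "m")      # duplicate names across transitions
    if !(nm in names)
        push!(names, nm)
    end
end
@assert names == ["m", "s"]    # first-seen order, no duplicates
```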

@@ -356,7 +358,7 @@ function AbstractMCMC.bundle_samples(

# Convert transitions to array format.
# Also retrieve the variable names.
-    nms, vals = _params_to_array(ts, spl)
+    nms, vals = _params_to_array(ts)

# Get the values of the extra parameters in each Transition struct.
extra_params, extra_values = get_transition_extras(ts)
235 changes: 235 additions & 0 deletions src/modes/ModeEstimation.jl
@@ -0,0 +1,235 @@
module ModeEstimation

using ..Turing
using ..Bijectors
using LinearAlgebra

import ..DynamicPPL
import Optim
import NamedArrays
import ..ForwardDiff

export MAP, MLE

"""
    ModeResult{
        V<:NamedArrays.NamedArray,
        M<:Union{Missing, NamedArrays.NamedArray},
        O<:Optim.MultivariateOptimizationResults,
        S<:NamedArrays.NamedArray
    }

A wrapper struct to store various results from a MAP or MLE estimation.

Fields:

- `values` is a vector with the resulting point estimates.
- `info_matrix` is the inverse Hessian of the objective at the mode, or `missing`
  if it could not be computed.
- `optim_result` is the stored Optim.jl result.
- `summary_table` is a summary table with parameters, standard errors, and
  t-statistics computed from the information matrix.
- `lp` is the value of the objective (the negative log density) at the mode.
"""
struct ModeResult{
V<:NamedArrays.NamedArray,
M<:Union{Missing, NamedArrays.NamedArray},
O<:Optim.MultivariateOptimizationResults,
S<:NamedArrays.NamedArray
}
values :: V
info_matrix :: M
optim_result :: O
summary_table :: S
lp :: Float64
end

function Base.show(io::IO, m::ModeResult)
show(io, m.summary_table)
end
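A hedged sketch of inspecting a returned `ModeResult` (the value `r` is hypothetical; the field names are those of the struct above):

```julia
# Assuming `r` is a ModeResult returned by MLE or MAP:
r.values          # NamedArray of point estimates
r.summary_table   # parameters, standard errors, t-statistics
r.info_matrix     # inverse Hessian, or `missing` if it could not be computed
r.lp              # objective value (negative log density) at the optimum
```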

"""
    make_logjoint(model::DynamicPPL.Model, ctx::DynamicPPL.AbstractContext)

Constructs a log density function that accepts a vector `z` and returns the
negative log density of `model` evaluated at `z`, running the model with the
provided context `ctx`. The returned function accepts a keyword argument
`unlinked` indicating whether `z` is in the original (unlinked) space rather
than the transformed (linked) space.
"""
function make_logjoint(model::DynamicPPL.Model, ctx::DynamicPPL.AbstractContext)
# setup
varinfo_init = Turing.VarInfo(model)
spl = DynamicPPL.SampleFromPrior()
DynamicPPL.link!(varinfo_init, spl)

    function logπ(z; unlinked = false)
        varinfo = DynamicPPL.VarInfo(varinfo_init, spl, z)

        # If `z` is in the original (unlinked) space, temporarily undo the
        # linking before evaluating the model, and restore it afterwards.
        unlinked && DynamicPPL.invlink!(varinfo_init, spl)
        model(varinfo, spl, ctx)
        unlinked && DynamicPPL.link!(varinfo_init, spl)

        return -DynamicPPL.getlogp(varinfo)
    end

return logπ
end
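A minimal usage sketch (not part of the PR; `make_logjoint` is internal, so it is reached through the module path, which is an assumption about how the package is laid out):

```julia
using Turing

@model function g()
    m ~ Normal(0, 1)
    1.0 ~ Normal(m, 1)
end

lpf = Turing.ModeEstimation.make_logjoint(g(), Turing.DynamicPPL.LikelihoodContext())
lpf(zeros(1))                   # negative log likelihood at m = 0
lpf(zeros(1); unlinked = true)  # same point, interpreted as unlinked
```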

"""
mode_estimation(
model::DynamicPPL.Model,
lpf;
optim_options=Optim.Options(),
kwargs...
)

An internal function that handles the computation of an MLE or MAP estimate.

Arguments:

- `model` is a `DynamicPPL.Model`.
- `lpf` is a function returned by `make_logjoint`.

Keyword arguments:

- `optim_options` is an `Optim.Options` struct that allows you to configure the
  optimizer, e.g. by changing the number of iterations.

"""
function mode_estimation(
model::DynamicPPL.Model,
lpf;
optim_options=Optim.Options(),
kwargs...
)
# Do some initialization.
b = bijector(model)
binv = inv(b)

spl = DynamicPPL.SampleFromPrior()
vi = DynamicPPL.VarInfo(model)
init_params = model(vi, spl)
init_vals = vi[spl]

# Construct target function.
target(x) = lpf(x)
hess_target(x) = lpf(x; unlinked=true)

# Optimize!
M = Optim.optimize(target, init_vals, optim_options)

# Retrieve the estimated values.
vals = binv(M.minimizer)

# Get the VarInfo at the MLE/MAP point, and run the model to ensure
# correct dimensionality.
vi[spl] = vals
model(vi) # XXX: Is this a necessary step?

# Make one transition to get the parameter names.
ts = [Turing.Inference.Transition(DynamicPPL.tonamedtuple(vi), DynamicPPL.getlogp(vi))]
varnames, _ = Turing.Inference._params_to_array(ts)

# Store the parameters and their names in an array.
vmat = NamedArrays.NamedArray(vals, varnames)

# Try to generate the information matrix.
try
# Calculate Hessian and information matrix.
info = ForwardDiff.hessian(hess_target, vals)
info = inv(info)
mat = NamedArrays.NamedArray(info, (varnames, varnames))

# Create the standard errors.
ses = sqrt.(diag(info))

# Calculate t-stats.
tstat = vals ./ ses

# Make a summary table.
stable = NamedArrays.NamedArray(
[vals ses tstat],
(varnames, ["parameter", "std_err", "tstat"]))

# Return a wrapped-up table.
return ModeResult(vmat, mat, M, stable, M.minimum)
catch err
@warn "Could not compute Hessian matrix" err
        stable = NamedArrays.NamedArray(
            [vals repeat([missing], length(vals)) repeat([missing], length(vals))],
            (varnames, ["parameter", "std_err", "tstat"]))
return ModeResult(vmat, missing, M, stable, M.minimum)
end
end
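To make the summary statistics concrete, here is a self-contained sketch (toy numbers, not from the PR) of how the standard errors and t-statistics follow from the Hessian:

```julia
using LinearAlgebra

H = [4.0 1.0; 1.0 3.0]     # Hessian of the negative log density at the mode
vals = [0.5, -0.25]        # toy point estimates
info = inv(H)              # inverse Hessian ≈ asymptotic covariance matrix
ses = sqrt.(diag(info))    # standard errors, as in the try-block above
tstat = vals ./ ses        # t-statistics
```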

"""
MLE(model::DynamicPPL.Model; kwargs...)

Returns a maximum likelihood estimate of the given `model`.

Arguments:

- `model` is a `DynamicPPL.Model`.

Keyword arguments:

- `optim_options` is an `Optim.Options` struct that allows you to configure the
  optimizer, e.g. by changing the number of iterations.

Usage:

```julia
using Turing, Optim

@model function f()
m ~ Normal(0, 1)
1.5 ~ Normal(m, 1)
2.0 ~ Normal(m, 1)
end

model = f()
mle_estimate = MLE(model)

# Manually set the optimizer's settings.
mle_estimate = MLE(model, optim_options=Optim.Options(iterations=500))
```
"""
function MLE(model::DynamicPPL.Model; kwargs...)
lpf = make_logjoint(model, DynamicPPL.LikelihoodContext())
return mode_estimation(model, lpf; kwargs...)
end

"""
MAP(model::DynamicPPL.Model; kwargs...)

Returns the maximum a posteriori estimate of the given `model`.

Arguments:

- `model` is a `DynamicPPL.Model`.

Keyword arguments:

- `optim_options` is an `Optim.Options` struct that allows you to configure the
  optimizer, e.g. by changing the number of iterations.

Usage:

```julia
using Turing, Optim

@model function f()
m ~ Normal(0, 1)
1.5 ~ Normal(m, 1)
2.0 ~ Normal(m, 1)
end

model = f()
map_estimate = MAP(model)

# Manually set the optimizer's settings.
map_estimate = MAP(model, optim_options=Optim.Options(iterations=500))
```
"""
function MAP(model::DynamicPPL.Model; kwargs...)
lpf = make_logjoint(model, DynamicPPL.DefaultContext())
return mode_estimation(model, lpf; kwargs...)
end
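The only difference between `MLE` and `MAP` is the context passed to `make_logjoint`: `LikelihoodContext` drops the prior, while `DefaultContext` keeps it. For the conjugate example model in the docstrings above the two modes can be checked analytically, as in this hedged sketch:

```julia
model = f()                 # the example model from the docstrings above
mle_estimate = MLE(model)   # likelihood only: mode at m = (1.5 + 2.0)/2 = 1.75
map_estimate = MAP(model)   # likelihood × prior: mode at m = 3.5/3 ≈ 1.167
```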

end #module