-
Notifications
You must be signed in to change notification settings - Fork 226
Add MLE/MAP functionality #1230
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Changes from 4 commits
Commits
Show all changes
41 commits
Select commit
Hold shift + click to select a range
5f24ae3
Add Optim dependency
cpfiffer f8edbab
Export MLE/MAP
cpfiffer 7f6a03b
Fix stupid _params_to_array behavior
cpfiffer d682d9b
Add MLE/MAP
cpfiffer 3611719
Update src/modes/ModeEstimation.jl
cpfiffer 43b1cc8
Merge branch 'master' into csp/modes
cpfiffer 40c80db
Change optimizer.
cpfiffer 23a36bc
Merge branch 'master' into csp/modes
cpfiffer 2ba1c81
Match csp/hessian-bug
cpfiffer fd4a4de
Merge branch 'master' into csp/modes
cpfiffer dc0dd1c
Addressing comments, fixing bugs, adding tests
cpfiffer dca46e8
Removed extraneous model call
cpfiffer e616eb8
Add docstringsd
cpfiffer f9fa51f
Update src/modes/ModeEstimation.jl
cpfiffer 3fc51d9
Minor corrections.
cpfiffer 3e8c64c
Merge branch 'csp/modes' of github.com:TuringLang/Turing.jl into csp/…
cpfiffer b1a4861
Add NamedArrays to extras and compat
cpfiffer 4373c88
Fix dependencies
cpfiffer 0a769b7
Update tests & address comments
cpfiffer e660fb5
Correct Project.toml
cpfiffer 4fc1d73
Correct imports
cpfiffer cb56cd2
Renaming invlink
cpfiffer fab5b2d
Address comments
cpfiffer 1f2e8cb
Remove Optim from compat
cpfiffer d97f2dc
Minor correction
cpfiffer 16bbdf6
Update src/modes/ModeEstimation.jl
cpfiffer 403ddcd
Update src/modes/ModeEstimation.jl
cpfiffer 58fd2cf
Update src/modes/ModeEstimation.jl
cpfiffer 3eb5251
Use getval
cpfiffer a8b9869
Merge branch 'csp/modes' of github.com:TuringLang/Turing.jl into csp/…
cpfiffer 0809329
Update src/modes/ModeEstimation.jl
cpfiffer 065bc18
Tidying, fixing tests
cpfiffer e652e00
Merge branch 'csp/modes' of github.com:TuringLang/Turing.jl into csp/…
cpfiffer 1e5aea8
Replaced >= with >, because I am a fool
cpfiffer 199c00c
Use function notation for .~
cpfiffer bd0c1bb
Link the model vi after extracting optimized vals
cpfiffer 05ae0d4
Make sure linking status is right for Hessian
cpfiffer 345776c
Update src/Turing.jl
cpfiffer a2b159c
Update src/modes/ModeEstimation.jl
cpfiffer 856f282
Changer NamedArrays to dependency
cpfiffer 8cd87aa
Add warning if no convergence occurred, adapt to DynamicPPL master
cpfiffer File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -36,6 +36,8 @@ include("inference/Inference.jl") # inference algorithms | |
using .Inference | ||
include("variational/VariationalInference.jl") | ||
using .Variational | ||
include("modes/ModeEstimation.jl") | ||
using .ModeEstimation | ||
|
||
# TODO: re-design `sample` interface in MCMCChains, which unify CmdStan and Turing. | ||
# Related: https://github.com/TuringLang/Turing.jl/issues/746 | ||
|
@@ -85,7 +87,7 @@ export @model, # modelling | |
CSMC, | ||
PG, | ||
|
||
vi, # variational inference | ||
vi, # variational inference | ||
ADVI, | ||
|
||
sample, # inference | ||
|
@@ -108,5 +110,8 @@ export @model, # modelling | |
LogPoisson, | ||
NamedDist, | ||
filldist, | ||
arraydist | ||
arraydist, | ||
|
||
MLE, # mode estimation tools | ||
MAP | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Can you move these exports to the There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. You forgot to remove the exports here.
cpfiffer marked this conversation as resolved.
Show resolved
Hide resolved
|
||
end |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,235 @@ | ||
module ModeEstimation | ||
devmotion marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
using ..Turing | ||
using ..Bijectors | ||
using LinearAlgebra | ||
|
||
import ..DynamicPPL | ||
import Optim | ||
import NamedArrays | ||
import ..ForwardDiff | ||
|
||
export MAP, MLE | ||
|
||
""" | ||
ModeResult{ | ||
V<:NamedArrays.NamedArray, | ||
M<:NamedArrays.NamedArray, | ||
O<:Optim.MultivariateOptimizationResults, | ||
S<:NamedArrays.NamedArray | ||
} | ||
|
||
A wrapper struct to store various results from a MAP or MLE estimation. | ||
|
||
Fields: | ||
|
||
- `values` is a vector with the resulting point estimates | ||
- `info_matrix` is the inverse Hessian | ||
- `optim_result` is the stored Optim.jl results | ||
- `summary_table` is a summary table with parameters, standard errors, and | ||
t-statistics computed from the information matrix. | ||
- `lp` is the final likelihood. | ||
cpfiffer marked this conversation as resolved.
Show resolved
Hide resolved
|
||
""" | ||
struct ModeResult{ | ||
V<:NamedArrays.NamedArray, | ||
M<:Union{Missing, NamedArrays.NamedArray}, | ||
O<:Optim.MultivariateOptimizationResults, | ||
S<:NamedArrays.NamedArray | ||
} | ||
values :: V | ||
info_matrix :: M | ||
optim_result :: O | ||
summary_table :: S | ||
lp :: Float64 | ||
end | ||
|
||
function Base.show(io::IO, m::ModeResult) | ||
show(io, m.summary_table) | ||
end | ||
|
||
""" | ||
make_logjoint(model::DynamicPPL.Model, ctx::DynamicPPL.AbstractContext) | ||
|
||
Constructs a log density function that accepts a vector `z` and returns | ||
a tuple (-likelihood, `varinfo`). The model is run using the provided | ||
context `ctx`. | ||
cpfiffer marked this conversation as resolved.
Show resolved
Hide resolved
|
||
""" | ||
function make_logjoint(model::DynamicPPL.Model, ctx::DynamicPPL.AbstractContext) | ||
# setup | ||
varinfo_init = Turing.VarInfo(model) | ||
spl = DynamicPPL.SampleFromPrior() | ||
DynamicPPL.link!(varinfo_init, spl) | ||
|
||
function logπ(z; unlinked = false) | ||
varinfo = DynamicPPL.VarInfo(varinfo_init, spl, z) | ||
|
||
unlinked && DynamicPPL.invlink!(varinfo_init, spl) | ||
model(varinfo, spl, ctx) | ||
unlinked && DynamicPPL.link!(varinfo_init, spl) | ||
|
||
return -DynamicPPL.getlogp(varinfo) | ||
end | ||
|
||
return logπ | ||
end | ||
cpfiffer marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
""" | ||
mode_estimation( | ||
model::DynamicPPL.Model, | ||
lpf; | ||
optim_options=Optim.Options(), | ||
kwargs... | ||
) | ||
|
||
An internal function that handles the computation of a MLE or MAP estimate. | ||
cpfiffer marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
Arguments: | ||
|
||
- `model` is a `DynamicPPL.Model`. | ||
- `lpf` is a function returned by `make_logjoint`. | ||
|
||
Optional arguments: | ||
|
||
- `optim_options` is a `Optim.Options` struct that allows you to change the number | ||
of iterations run in an MLE estimate. | ||
|
||
cpfiffer marked this conversation as resolved.
Show resolved
Hide resolved
|
||
""" | ||
function mode_estimation( | ||
model::DynamicPPL.Model, | ||
lpf; | ||
cpfiffer marked this conversation as resolved.
Show resolved
Hide resolved
|
||
optim_options=Optim.Options(), | ||
cpfiffer marked this conversation as resolved.
Show resolved
Hide resolved
|
||
kwargs... | ||
) | ||
# Do some initialization. | ||
b = bijector(model) | ||
cpfiffer marked this conversation as resolved.
Show resolved
Hide resolved
|
||
binv = inv(b) | ||
|
||
spl = DynamicPPL.SampleFromPrior() | ||
vi = DynamicPPL.VarInfo(model) | ||
cpfiffer marked this conversation as resolved.
Show resolved
Hide resolved
|
||
init_params = model(vi, spl) | ||
cpfiffer marked this conversation as resolved.
Show resolved
Hide resolved
|
||
init_vals = vi[spl] | ||
|
||
# Construct target function. | ||
target(x) = lpf(x) | ||
hess_target(x) = lpf(x; unlinked=true) | ||
|
||
cpfiffer marked this conversation as resolved.
Show resolved
Hide resolved
|
||
# Optimize! | ||
M = Optim.optimize(target, init_vals, optim_options) | ||
|
||
# Retrieve the estimated values. | ||
vals = binv(M.minimizer) | ||
|
||
# Get the VarInfo at the MLE/MAP point, and run the model to ensure | ||
# correct dimensionality. | ||
vi[spl] = vals | ||
model(vi) # XXX: Is this a necessary step? | ||
cpfiffer marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
# Make one transition to get the parameter names. | ||
ts = [Turing.Inference.Transition(DynamicPPL.tonamedtuple(vi), DynamicPPL.getlogp(vi))] | ||
varnames, _ = Turing.Inference._params_to_array(ts) | ||
cpfiffer marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
# Store the parameters and their names in an array. | ||
vmat = NamedArrays.NamedArray(vals, varnames) | ||
|
||
# Try to generate the information matrix. | ||
try | ||
# Calculate Hessian and information matrix. | ||
info = ForwardDiff.hessian(hess_target, vals) | ||
info = inv(info) | ||
mat = NamedArrays.NamedArray(info, (varnames, varnames)) | ||
|
||
# Create the standard errors. | ||
ses = sqrt.(diag(info)) | ||
|
||
# Calculate t-stats. | ||
tstat = vals ./ ses | ||
|
||
# Make a summary table. | ||
stable = NamedArrays.NamedArray( | ||
[vals ses tstat], | ||
(varnames, ["parameter", "std_err", "tstat"])) | ||
|
||
# Return a wrapped-up table. | ||
return ModeResult(vmat, mat, M, stable, M.minimum) | ||
catch err | ||
@warn "Could not compute Hessian matrix" err | ||
stable = NamedArrays.NamedArray([vals repeat([missing], length(vals)) repeat([missing], length(vals))], (varnames, ["parameter", "std_err", "tstat"])) | ||
return ModeResult(vmat, missing, M, stable, M.minimum) | ||
end | ||
end | ||
|
||
""" | ||
MLE(model::DynamicPPL.Model; kwargs...) | ||
|
||
Returns a maximum likelihood estimate of the given `model`. | ||
cpfiffer marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
Arguments: | ||
|
||
- `model` is a `DynamicPPL.Model`. | ||
|
||
Keyword arguments: | ||
|
||
- `optim_options` is a `Optim.Options` struct that allows you to change the number | ||
of iterations run in an MLE estimate. | ||
cpfiffer marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
Usage: | ||
cpfiffer marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
```julia | ||
using Turing | ||
|
||
@model function f() | ||
m ~ Normal(0, 1) | ||
1.5 ~ Normal(m, 1) | ||
2.0 ~ Normal(m, 1) | ||
end | ||
|
||
model = f() | ||
mle_estimate = MLE(model) | ||
|
||
# Manually setting the optimizers settings. | ||
mle_estimate = MLE(model, optim_options=Optim.Options(iterations=500)) | ||
``` | ||
""" | ||
function MLE(model::DynamicPPL.Model; kwargs...) | ||
cpfiffer marked this conversation as resolved.
Show resolved
Hide resolved
|
||
lpf = make_logjoint(model, DynamicPPL.LikelihoodContext()) | ||
return mode_estimation(model, lpf; kwargs...) | ||
end | ||
|
||
""" | ||
MAP(model::DynamicPPL.Model; kwargs...) | ||
|
||
Returns the maximum a posteriori estimate of the given `model`. | ||
cpfiffer marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
Arguments: | ||
|
||
- `model` is a `DynamicPPL.Model`. | ||
|
||
Keyword arguments: | ||
|
||
- `optim_options` is a `Optim.Options` struct that allows you to change the number | ||
of iterations run in an MLE estimate. | ||
cpfiffer marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
Usage: | ||
cpfiffer marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
```julia | ||
using Turing | ||
|
||
@model function f() | ||
m ~ Normal(0, 1) | ||
1.5 ~ Normal(m, 1) | ||
2.0 ~ Normal(m, 1) | ||
end | ||
|
||
model = f() | ||
mle_estimate = MAP(model) | ||
|
||
# Manually setting the optimizers settings. | ||
mle_estimate = MAP(model, optim_options=Optim.Options(iterations=500)) | ||
``` | ||
""" | ||
function MAP(model::DynamicPPL.Model; kwargs...) | ||
cpfiffer marked this conversation as resolved.
Show resolved
Hide resolved
|
||
lpf = make_logjoint(model, DynamicPPL.DefaultContext()) | ||
return mode_estimation(model, lpf; kwargs...) | ||
end | ||
|
||
end #module |
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.