diff --git a/Project.toml b/Project.toml
index 716609900..f44160217 100644
--- a/Project.toml
+++ b/Project.toml
@@ -17,6 +17,8 @@ Libtask = "6f1fad26-d15e-5dc8-ae53-837a1d7b8c9f"
 LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
 LogDensityProblems = "6fdf6af0-433a-55f7-b3ed-c6c6e0b8df7c"
 MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d"
+NamedArrays = "86f7a689-2022-50b4-a561-43c23ac3c673"
+Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"
 ProgressLogging = "33c8b6b6-d38a-422a-b730-caa89a2f386c"
 Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
 Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
@@ -41,6 +43,8 @@ ForwardDiff = "0.10.3"
 Libtask = "0.4"
 LogDensityProblems = "^0.9, 0.10"
 MCMCChains = "3.0.7"
+NamedArrays = "0.9"
+Optim = "0.20, 0.21"
 ProgressLogging = "0.1"
 Reexport = "0.2.0"
 Requires = "0.5, 1.0"
@@ -56,6 +60,7 @@ CmdStan = "593b3428-ca2f-500c-ae53-031589ec8ddd"
 DynamicHMC = "bbc10e6e-7c05-544b-b16e-64fede858acb"
 FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000"
 Memoization = "6fafb56a-5788-4b4e-91ca-c0cea6611c73"
+Optim = "429524aa-4258-5aef-a3af-852621145aeb"
 PDMats = "90014a1f-27ba-587c-ab20-58faa44d9150"
 Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
 ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
@@ -66,4 +71,4 @@ UnicodePlots = "b8865327-cd53-5732-bb35-84acbb429228"
 Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
 
 [targets]
-test = ["Pkg", "PDMats", "TerminalLoggers", "Test", "UnicodePlots", "StatsBase", "FiniteDifferences", "DynamicHMC", "CmdStan", "BenchmarkTools", "Zygote", "ReverseDiff", "Memoization"]
+test = ["Pkg", "PDMats", "TerminalLoggers", "Test", "UnicodePlots", "StatsBase", "FiniteDifferences", "DynamicHMC", "CmdStan", "BenchmarkTools", "Zygote", "ReverseDiff", "Memoization", "Optim"]
diff --git a/src/Turing.jl b/src/Turing.jl
index f24822c47..35fab3da3 100644
--- a/src/Turing.jl
+++ b/src/Turing.jl
@@ -58,6 +58,11 @@ using .Variational
     end
 end
 
+@init @require Optim="429524aa-4258-5aef-a3af-852621145aeb" @eval begin
+    include("modes/ModeEstimation.jl")
+    export MAP, MLE, optimize
+end
+
 ###########
 # Exports #
 ###########
@@ -87,7 +92,7 @@ export @model,                 # modelling
         CSMC,
         PG,
 
-        vi, # variational inference
+        vi,                    # variational inference
         ADVI,
 
         sample,                # inference
diff --git a/src/inference/Inference.jl b/src/inference/Inference.jl
index b3db2982d..19e61ba4d 100644
--- a/src/inference/Inference.jl
+++ b/src/inference/Inference.jl
@@ -305,19 +305,20 @@ Return a named tuple of parameters.
 getparams(t) = t.θ
 getparams(t::VarInfo) = tonamedtuple(TypedVarInfo(t))
 
-function _params_to_array(ts)
-    names_set = Set{String}()
+function _params_to_array(ts::Vector)
+    names = Vector{String}()
     # Extract the parameter names and values from each transition.
     dicts = map(ts) do t
         nms, vs = flatten_namedtuple(getparams(t))
         for nm in nms
-            push!(names_set, nm)
+            if !(nm in names)
+                push!(names, nm)
+            end
        end
         # Convert the names and values to a single dictionary.
         return Dict(nms[j] => vs[j] for j in 1:length(vs))
     end
-    names = collect(names_set)
     vals = [get(dicts[i], key, missing) for i in eachindex(dicts),
         (j, key) in enumerate(names)]
 
     return names, vals
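A note on the `_params_to_array` change above: it is behavioral, not just cosmetic. A `Set` iterates in an unspecified order, so parameter columns could be permuted from run to run; a `Vector` with a membership check preserves first-seen order. A minimal standalone sketch of the new collection logic (hypothetical transition data, not Turing internals):

```julia
# Hypothetical flattened (name => value) pairs from two transitions.
transitions = [["s" => 2.0, "m" => 0.5], ["m" => 0.7, "t" => 1.0]]

names = String[]
for t in transitions, (nm, _) in t
    nm in names || push!(names, nm)    # record each name once, in first-seen order
end
names == ["s", "m", "t"]    # true: deterministic across runs

# Parameters absent from a transition become `missing`, as in `_params_to_array`.
vals = [get(Dict(transitions[i]), key, missing)
        for i in eachindex(transitions), key in names]
```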
diff --git a/src/modes/ModeEstimation.jl b/src/modes/ModeEstimation.jl
new file mode 100644
index 000000000..00f8c505e
--- /dev/null
+++ b/src/modes/ModeEstimation.jl
@@ -0,0 +1,320 @@
+using ..Turing
+using ..Bijectors
+using LinearAlgebra
+
+import ..AbstractMCMC: AbstractSampler
+import ..DynamicPPL
+import ..DynamicPPL: Model, AbstractContext, VarInfo, VarName,
+    _getindex, getsym, getfield, settrans!, setorder!,
+    get_and_set_val!, istrans, tilde, dot_tilde, get_vns_and_dist
+import .Optim
+import .Optim: optimize
+import ..ForwardDiff
+import NamedArrays
+import StatsBase
+import Printf
+
+struct MLE end
+struct MAP end
+
+"""
+    OptimizationContext{C<:AbstractContext} <: AbstractContext
+
+The `OptimizationContext` transforms variables to their constrained space, but
+does not add the log absolute Jacobian of the transformation to the density.
+This lets an optimizer search freely over R^n while the objective remains the
+density of the constrained parameters.
+"""
+struct OptimizationContext{C<:AbstractContext} <: AbstractContext
+    context::C
+end
+
+# assume
+function DynamicPPL.tilde(rng, ctx::OptimizationContext, spl, dist, vn::VarName, inds, vi)
+    return DynamicPPL.tilde(ctx, spl, dist, vn, inds, vi)
+end
+
+function DynamicPPL.tilde(ctx::OptimizationContext{<:LikelihoodContext}, spl, dist, vn::VarName, inds, vi)
+    r = vi[vn]
+    return r, 0
+end
+
+function DynamicPPL.tilde(ctx::OptimizationContext, spl, dist, vn::VarName, inds, vi)
+    r = vi[vn]
+    return r, Distributions.logpdf(dist, r)
+end
+
+# observe
+function DynamicPPL.tilde(rng, ctx::OptimizationContext, sampler, right, left, vi)
+    return DynamicPPL.tilde(ctx, sampler, right, left, vi)
+end
+
+function DynamicPPL.tilde(ctx::OptimizationContext{<:PriorContext}, sampler, right, left, vi)
+    return 0
+end
+
+function DynamicPPL.tilde(ctx::OptimizationContext, sampler, dist, value, vi)
+    return Distributions.logpdf(dist, value)
+end
+
+# dot assume
+function DynamicPPL.dot_tilde(rng, ctx::OptimizationContext, sampler, right, left, vn::VarName, inds, vi)
+    return DynamicPPL.dot_tilde(ctx, sampler, right, left, vn, inds, vi)
+end
+
+function DynamicPPL.dot_tilde(ctx::OptimizationContext{<:LikelihoodContext}, sampler, right, left, vn::VarName, _, vi)
+    vns, dist = get_vns_and_dist(right, left, vn)
+    r = getval(vi, vns)
+    return r, 0
+end
+
+function DynamicPPL.dot_tilde(ctx::OptimizationContext, sampler, right, left, vn::VarName, _, vi)
+    vns, dist = get_vns_and_dist(right, left, vn)
+    r = getval(vi, vns)
+    return r, loglikelihood(dist, r)
+end
+
+# dot observe
+function DynamicPPL.dot_tilde(ctx::OptimizationContext{<:PriorContext}, sampler, right, left, vn, _, vi)
+    return 0
+end
+
+function DynamicPPL.dot_tilde(ctx::OptimizationContext{<:PriorContext}, sampler, right, left, vi)
+    return 0
+end
+
+function DynamicPPL.dot_tilde(ctx::OptimizationContext, sampler, right, left, vn, _, vi)
+    vns, dist = get_vns_and_dist(right, left, vn)
+    r = getval(vi, vns)
+    return loglikelihood(dist, r)
+end
+
+function DynamicPPL.dot_tilde(ctx::OptimizationContext, sampler, dists, value, vi)
+    return sum(Distributions.logpdf.(dists, value))
+end
+
+function getval(
+    vi,
+    vns::AbstractVector{<:VarName},
+)
+    r = vi[vns]
+    return r
+end
+
+function getval(
+    vi,
+    vns::AbstractArray{<:VarName},
+)
+    r = reshape(vi[vec(vns)], size(vns))
+    return r
+end
+
+"""
+    OptimLogDensity{M<:Model,C<:AbstractContext,V<:VarInfo}
+
+A struct that stores the negative log density function of a `DynamicPPL` model.
+"""
+struct OptimLogDensity{M<:Model,C<:AbstractContext,V<:VarInfo}
+    "A `DynamicPPL.Model` constructed either with the `@model` macro or manually."
+    model::M
+    "A `DynamicPPL.AbstractContext` used to evaluate the model. `LikelihoodContext` or `DefaultContext` are typical for MLE/MAP."
+    context::C
+    "A `DynamicPPL.VarInfo` struct that will be used to update model parameters."
+    vi::V
+end
+
+"""
+    OptimLogDensity(model::Model, context::AbstractContext)
+
+Create a callable `OptimLogDensity` struct that evaluates a model using the given `context`.
+"""
+function OptimLogDensity(model::Model, context::AbstractContext)
+    init = VarInfo(model)
+    DynamicPPL.link!(init, DynamicPPL.SampleFromPrior())
+    return OptimLogDensity(model, context, init)
+end
+
+"""
+    (f::OptimLogDensity)(z)
+
+Evaluate the negative log joint (with `DefaultContext`) or the negative log
+likelihood (with `LikelihoodContext`) at the array `z`.
+"""
+function (f::OptimLogDensity)(z)
+    spl = DynamicPPL.SampleFromPrior()
+
+    varinfo = DynamicPPL.VarInfo(f.vi, spl, z)
+    f.model(varinfo, spl, f.context)
+    return -DynamicPPL.getlogp(varinfo)
+end
+
+"""
+    ModeResult{
+        V<:NamedArrays.NamedArray,
+        O<:Optim.MultivariateOptimizationResults,
+        M<:OptimLogDensity
+    }
+
+A wrapper struct to store various results from a MAP or MLE estimation.
+"""
+struct ModeResult{
+    V<:NamedArrays.NamedArray,
+    O<:Optim.MultivariateOptimizationResults,
+    M<:OptimLogDensity
+} <: StatsBase.StatisticalModel
+    "A vector with the resulting point estimates."
+    values :: V
+    "The stored Optim.jl results."
+    optim_result :: O
+    "The final log likelihood or log joint, depending on whether `MLE` or `MAP` was run."
+    lp :: Float64
+    "The evaluation function used to calculate the output."
+    f :: M
+end
+
+#############################
+# Various StatsBase methods #
+#############################
+
+function Base.show(io::IO, ::MIME"text/plain", m::ModeResult)
+    print(io, "ModeResult with minimized lp of ")
+    Printf.@printf(io, "%.2f", m.lp)
+    println(io)
+    show(io, m.values)
+end
+
+function Base.show(io::IO, m::ModeResult)
+    show(io, m.values.array)
+end
+
+function StatsBase.coeftable(m::ModeResult)
+    # Get columns for coeftable.
+    terms = StatsBase.coefnames(m)
+    estimates = m.values.array[:,1]
+    stderrors = StatsBase.stderror(m)
+    tstats = estimates ./ stderrors
+
+    StatsBase.CoefTable([estimates, stderrors, tstats], ["estimate", "stderror", "tstat"], terms)
+end
+
+function StatsBase.informationmatrix(m::ModeResult; hessian_function=ForwardDiff.hessian, kwargs...)
+    # Calculate Hessian and information matrix.
+
+    # Convert the values to their unconstrained states to make sure the
+    # Hessian is computed with respect to the untransformed parameters.
+    spl = DynamicPPL.SampleFromPrior()
+
+    # NOTE: This should be converted to islinked(vi, spl) after
+    # https://github.com/TuringLang/DynamicPPL.jl/pull/124 goes through.
+    vns = DynamicPPL._getvns(m.f.vi, spl)
+    linked = DynamicPPL._islinked(m.f.vi, vns)
+    linked && invlink!(m.f.vi, spl)
+
+    # Calculate the Hessian.
+    varnames = StatsBase.coefnames(m)
+    H = hessian_function(m.f, m.values.array[:, 1])
+    info = inv(H)
+
+    # Link it back if we invlinked it.
+    linked && link!(m.f.vi, spl)
+
+    return NamedArrays.NamedArray(info, (varnames, varnames))
+end
+
+StatsBase.coef(m::ModeResult) = m.values
+StatsBase.coefnames(m::ModeResult) = names(m.values)[1]
+StatsBase.params(m::ModeResult) = StatsBase.coefnames(m)
+StatsBase.vcov(m::ModeResult) = StatsBase.informationmatrix(m)
+StatsBase.loglikelihood(m::ModeResult) = m.lp
+
+####################
+# Optim.jl methods #
+####################
+
+"""
+    Optim.optimize(model::Model, ::MLE, args...; kwargs...)
+
+Compute a maximum likelihood estimate of the `model`.
+
+# Examples
+
+```julia-repl
+@model function f(x)
+    m ~ Normal(0, 1)
+    x ~ Normal(m, 1)
+end
+
+model = f(1.5)
+mle = optimize(model, MLE())
+
+# Use a different optimizer
+mle = optimize(model, MLE(), NelderMead())
+```
+"""
+function Optim.optimize(model::Model, ::MLE, args...; kwargs...)
+    ctx = OptimizationContext(DynamicPPL.LikelihoodContext())
+    return optimize(model, OptimLogDensity(model, ctx), args...; kwargs...)
+end
+
+"""
+    Optim.optimize(model::Model, ::MAP, args...; kwargs...)
+
+Compute a maximum a posteriori estimate of the `model`.
+
+# Examples
+
+```julia-repl
+@model function f(x)
+    m ~ Normal(0, 1)
+    x ~ Normal(m, 1)
+end
+
+model = f(1.5)
+map_est = optimize(model, MAP())
+
+# Use a different optimizer
+map_est = optimize(model, MAP(), NelderMead())
+```
+"""
+function Optim.optimize(model::Model, ::MAP, args...; kwargs...)
+    ctx = OptimizationContext(DynamicPPL.DefaultContext())
+    return optimize(model, OptimLogDensity(model, ctx), args...; kwargs...)
+end
+
+"""
+    Optim.optimize(model::Model, f::OptimLogDensity, optimizer=Optim.LBFGS(), args...; kwargs...)
+
+Estimate a mode, i.e., compute an MLE or MAP estimate.
+"""
+function Optim.optimize(model::Model, f::OptimLogDensity, optimizer=Optim.LBFGS(), args...; kwargs...)
+    # Do some initialization.
+    spl = DynamicPPL.SampleFromPrior()
+    init_vals = f.vi[spl]
+
+    # Optimize!
+    M = Optim.optimize(f, init_vals, optimizer, args...; kwargs...)
+
+    # Warn the user if the optimization did not converge.
+    if !Optim.converged(M)
+        @warn "Optimization did not converge! You may need to correct your model or adjust the Optim parameters."
+    end
+
+    # Get the VarInfo at the MLE/MAP point, and run the model to ensure
+    # correct dimensionality.
+    f.vi[spl] = M.minimizer
+    invlink!(f.vi, spl)
+    vals = f.vi[spl]
+    link!(f.vi, spl)
+
+    # Make one transition to get the parameter names.
+    ts = [Turing.Inference.Transition(DynamicPPL.tonamedtuple(f.vi), DynamicPPL.getlogp(f.vi))]
+    varnames, _ = Turing.Inference._params_to_array(ts)
+
+    # Store the parameters and their names in an array.
+    vmat = NamedArrays.NamedArray(vals, varnames)
+
+    return ModeResult(vmat, M, -M.minimum, f)
+end
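Taken together, the new file wires Optim.jl optimizers to Turing models. The intended workflow is sketched below; the `gdemo` definition is an assumption inferred from the tests' `gdemo_default` (the true values `49/54` and `7/6` are consistent with it), so treat it as illustrative rather than canonical:

```julia
using Turing, Optim, StatsBase

# A gdemo-style model, assumed to match the tests' gdemo_default.
@model function gdemo(x, y)
    s ~ InverseGamma(2, 3)
    m ~ Normal(0, sqrt(s))
    x ~ Normal(m, sqrt(s))
    y ~ Normal(m, sqrt(s))
end

model = gdemo(1.5, 2.0)

mle = optimize(model, MLE())                    # LBFGS by default
map_est = optimize(model, MAP(), NelderMead())  # any Optim optimizer works

coeftable(mle)   # estimates, stderrors, t-statistics
vcov(mle)        # inverse Hessian at the mode (observed information)
```

Note that `informationmatrix` returns the inverse of the Hessian of the negative log density, i.e., the asymptotic covariance at the mode, which is why `vcov` simply forwards to it.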
diff --git a/test/modes/ModeEstimation.jl b/test/modes/ModeEstimation.jl
new file mode 100644
index 000000000..016f1b8a6
--- /dev/null
+++ b/test/modes/ModeEstimation.jl
@@ -0,0 +1,105 @@
+using Turing
+using Optim
+using Test
+using StatsBase
+using NamedArrays
+using ReverseDiff
+using Random
+using LinearAlgebra
+
+dir = splitdir(splitdir(pathof(Turing))[1])[1]
+include(dir*"/test/test_utils/AllUtils.jl")
+
+@testset "ModeEstimation.jl" begin
+    @testset "MLE" begin
+        Random.seed!(222)
+
+        m1 = optimize(gdemo_default, MLE())
+        m2 = optimize(gdemo_default, MLE(), NelderMead())
+        m3 = optimize(gdemo_default, MLE(), Newton())
+
+        true_value = [0.0625031, 1.75]
+        @test all(isapprox.(m1.values.array - true_value, 0.0, atol=0.01))
+        @test all(isapprox.(m2.values.array - true_value, 0.0, atol=0.01))
+        @test all(isapprox.(m3.values.array - true_value, 0.0, atol=0.01))
+    end
+
+    @testset "MAP" begin
+        Random.seed!(222)
+
+        m1 = optimize(gdemo_default, MAP())
+        m2 = optimize(gdemo_default, MAP(), NelderMead())
+        m3 = optimize(gdemo_default, MAP(), Newton())
+
+        true_value = [49 / 54, 7 / 6]
+        @test all(isapprox.(m1.values.array - true_value, 0.0, atol=0.01))
+        @test all(isapprox.(m2.values.array - true_value, 0.0, atol=0.01))
+        @test all(isapprox.(m3.values.array - true_value, 0.0, atol=0.01))
+    end
+
+    @testset "StatsBase integration" begin
+        Random.seed!(54321)
+        mle_est = optimize(gdemo_default, MLE())
+
+        @test coefnames(mle_est) == ["s", "m"]
+
+        diffs = coef(mle_est).array - [0.0625031; 1.75001]
+        @test all(isapprox.(diffs, 0.0, atol=0.1))
+
+        infomat = [0.003907027690416608 4.157954948417027e-7; 4.157954948417027e-7 0.03125155528962335]
+        @test all(isapprox.(infomat - informationmatrix(mle_est), 0.0, atol=0.01))
+
+        ctable = coeftable(mle_est)
+        @test ctable isa StatsBase.CoefTable
+
+        s = stderror(mle_est).array
+        @test all(isapprox.(s - [0.06250415643292194, 0.17677963626053916], 0.0, atol=0.01))
+
+        @test coefnames(mle_est) == params(mle_est)
+        @test vcov(mle_est) == informationmatrix(mle_est)
+
+        @test isapprox(loglikelihood(mle_est), -0.0652883561466624, atol=0.01)
+    end
+
+    @testset "Linear regression test" begin
+        @model function regtest(x, y)
+            beta ~ MvNormal(2, 1)
+            mu = x * beta
+            y ~ MvNormal(mu, 1.0)
+        end
+
+        Random.seed!(987)
+        true_beta = [1.0, -2.2]
+        x = rand(40, 2)
+        y = x * true_beta
+
+        model = regtest(x, y)
+        mle = optimize(model, MLE())
+
+        vcmat = inv(x'x)
+        vcmat_mle = informationmatrix(mle).array
+
+        @test isapprox(mle.values.array, true_beta)
+        @test isapprox(vcmat, vcmat_mle)
+    end
+
+    @testset "Dot tilde test" begin
+        @model function dot_gdemo(x)
+            s ~ InverseGamma(2, 3)
+            m ~ Normal(0, sqrt(s))
+
+            (.~)(x, Normal(m, sqrt(s)))
+        end
+
+        model_dot = dot_gdemo([1.5, 2.0])
+
+        mle1 = optimize(gdemo_default, MLE())
+        mle2 = optimize(model_dot, MLE())
+
+        map1 = optimize(gdemo_default, MAP())
+        map2 = optimize(model_dot, MAP())
+
+        @test isapprox(mle1.values.array, mle2.values.array)
+        @test isapprox(map1.values.array, map2.values.array)
+    end
+end
\ No newline at end of file
diff --git a/test/runtests.jl b/test/runtests.jl
index 5776651bb..3223d6207 100644
--- a/test/runtests.jl
+++ b/test/runtests.jl
@@ -53,4 +53,8 @@ include("test_utils/AllUtils.jl")
     @testset "utilities" begin
         # include("utilities/stan-interface.jl")
     end
+
+    @testset "modes" begin
+        include("modes/ModeEstimation.jl")
+    end
 end
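A syntax note on the "Dot tilde test" above: `(.~)(x, Normal(m, sqrt(s)))` is the function-call spelling of the broadcasted tilde, which the `@model` macro routes through the `dot_tilde` methods added in this diff. It is equivalent to the infix form:

```julia
@model function dot_gdemo(x)
    s ~ InverseGamma(2, 3)
    m ~ Normal(0, sqrt(s))
    x .~ Normal(m, sqrt(s))    # each element of x is an iid observation
end
```

Finally, a word on where the tests' standard errors come from. Since `OptimLogDensity` returns the negative log density, the Hessian `H` at the optimum is the observed information matrix, and `informationmatrix` returns `inv(H)`; the square roots of its diagonal are the standard errors reported by `coeftable`. A numerical sketch with hypothetical values (matching no particular model):

```julia
using LinearAlgebra

H = [16.0 0.1; 0.1 4.0]             # hypothetical Hessian of the negative log density at the mode
vc = inv(H)                         # what informationmatrix / vcov return
stderrors = sqrt.(diag(vc))         # what stderror reports
tstats = [0.06, 1.75] ./ stderrors  # estimate ./ stderror, as in coeftable
```

This also explains the linear regression test: for a unit-variance Gaussian likelihood, the information matrix is `x'x`, so `informationmatrix(mle)` should recover `inv(x'x)`.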