Skip to content

Commit 59e5cce

Browse files
DominiqueMakowskidevmotionyebai
authored
New Feature: Fix and improve coeftable for otpimize() output (#2034)
* Fix #2033 * Propose additional columns * Add other suggestions * Update src/modes/OptimInterface.jl Co-authored-by: David Widmann <[email protected]> * Update src/modes/OptimInterface.jl Co-authored-by: David Widmann <[email protected]> * Update src/modes/OptimInterface.jl Co-authored-by: David Widmann <[email protected]> * bump version * import SsStatsAPI * Update Project.toml * Update src/modes/OptimInterface.jl Co-authored-by: David Widmann <[email protected]> --------- Co-authored-by: David Widmann <[email protected]> Co-authored-by: Hong Ge <[email protected]>
1 parent 90e1d21 commit 59e5cce

File tree

2 files changed

+59
-47
lines changed

2 files changed

+59
-47
lines changed

Project.toml

+3-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "Turing"
22
uuid = "fce5fe82-541a-59a6-adf8-730c64b5f9a0"
3-
version = "0.26.3"
3+
version = "0.26.4"
44

55
[deps]
66
AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"
@@ -32,6 +32,7 @@ SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462"
3232
Setfield = "efcf1570-3423-57d1-acb7-fd33fddbac46"
3333
SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"
3434
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
35+
StatsAPI = "82ae8749-77ed-4fe6-ae5f-f523153014b0"
3536
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
3637
StatsFuns = "4c63d2b9-4356-54db-8cca-17b64c39e42c"
3738
Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"
@@ -62,6 +63,7 @@ Requires = "0.5, 1.0"
6263
SciMLBase = "1.37.1"
6364
Setfield = "0.8, 1"
6465
SpecialFunctions = "0.7.2, 0.8, 0.9, 0.10, 1, 2"
66+
StatsAPI = "1.6"
6567
StatsBase = "0.32, 0.33, 0.34"
6668
StatsFuns = "0.8, 0.9, 1"
6769
Tracker = "0.2.3"

src/modes/OptimInterface.jl

+56-46
Original file line numberDiff line numberDiff line change
@@ -7,31 +7,32 @@ import ..ForwardDiff
77
import NamedArrays
88
import StatsBase
99
import Printf
10+
import StatsAPI
1011

1112

1213
"""
1314
ModeResult{
14-
V<:NamedArrays.NamedArray,
15-
M<:NamedArrays.NamedArray,
16-
O<:Optim.MultivariateOptimizationResults,
15+
V<:NamedArrays.NamedArray,
16+
M<:NamedArrays.NamedArray,
17+
O<:Optim.MultivariateOptimizationResults,
1718
S<:NamedArrays.NamedArray
1819
}
1920
2021
A wrapper struct to store various results from a MAP or MLE estimation.
2122
"""
2223
struct ModeResult{
23-
V<:NamedArrays.NamedArray,
24+
V<:NamedArrays.NamedArray,
2425
O<:Optim.MultivariateOptimizationResults,
2526
M<:OptimLogDensity
2627
} <: StatsBase.StatisticalModel
2728
"A vector with the resulting point estimates."
28-
values :: V
29+
values::V
2930
"The stored Optim.jl results."
30-
optim_result :: O
31+
optim_result::O
3132
"The final log likelihood or log joint, depending on whether `MAP` or `MLE` was run."
32-
lp :: Float64
33+
lp::Float64
3334
"The evaluation function used to calculate the output."
34-
f :: M
35+
f::M
3536
end
3637
#############################
3738
# Various StatsBase methods #
@@ -50,14 +51,23 @@ function Base.show(io::IO, m::ModeResult)
5051
show(io, m.values.array)
5152
end
5253

53-
function StatsBase.coeftable(m::ModeResult)
54+
function StatsBase.coeftable(m::ModeResult; level::Real=0.95)
5455
# Get columns for coeftable.
55-
terms = StatsBase.coefnames(m)
56-
estimates = m.values.array[:,1]
56+
terms = string.(StatsBase.coefnames(m))
57+
estimates = m.values.array[:, 1]
5758
stderrors = StatsBase.stderror(m)
58-
tstats = estimates ./ stderrors
59-
60-
StatsBase.CoefTable([estimates, stderrors, tstats], ["estimate", "stderror", "tstat"], terms)
59+
zscore = estimates ./ stderrors
60+
p = map(z -> StatsAPI.pvalue(Normal(), z; tail=:both), zscore)
61+
62+
# Confidence interval (CI)
63+
q = quantile(Normal(), (1 + level) / 2)
64+
ci_low = estimates .- q .* stderrors
65+
ci_high = estimates .+ q .* stderrors
66+
67+
StatsBase.CoefTable(
68+
[estimates, stderrors, zscore, p, ci_low, ci_high],
69+
["Coef.", "Std. Error", "z", "Pr(>|z|)", "Lower 95%", "Upper 95%"],
70+
terms)
6171
end
6272

6373
function StatsBase.informationmatrix(m::ModeResult; hessian_function=ForwardDiff.hessian, kwargs...)
@@ -113,7 +123,7 @@ mle = optimize(model, MLE())
113123
mle = optimize(model, MLE(), NelderMead())
114124
```
115125
"""
116-
function Optim.optimize(model::Model, ::MLE, options::Optim.Options=Optim.Options(); kwargs...)
126+
function Optim.optimize(model::Model, ::MLE, options::Optim.Options=Optim.Options(); kwargs...)
117127
return _mle_optimize(model, options; kwargs...)
118128
end
119129
function Optim.optimize(model::Model, ::MLE, init_vals::AbstractArray, options::Optim.Options=Optim.Options(); kwargs...)
@@ -123,11 +133,11 @@ function Optim.optimize(model::Model, ::MLE, optimizer::Optim.AbstractOptimizer,
123133
return _mle_optimize(model, optimizer, options; kwargs...)
124134
end
125135
function Optim.optimize(
126-
model::Model,
127-
::MLE,
128-
init_vals::AbstractArray,
129-
optimizer::Optim.AbstractOptimizer,
130-
options::Optim.Options=Optim.Options();
136+
model::Model,
137+
::MLE,
138+
init_vals::AbstractArray,
139+
optimizer::Optim.AbstractOptimizer,
140+
options::Optim.Options=Optim.Options();
131141
kwargs...
132142
)
133143
return _mle_optimize(model, init_vals, optimizer, options; kwargs...)
@@ -159,7 +169,7 @@ map_est = optimize(model, MAP(), NelderMead())
159169
```
160170
"""
161171

162-
function Optim.optimize(model::Model, ::MAP, options::Optim.Options=Optim.Options(); kwargs...)
172+
function Optim.optimize(model::Model, ::MAP, options::Optim.Options=Optim.Options(); kwargs...)
163173
return _map_optimize(model, options; kwargs...)
164174
end
165175
function Optim.optimize(model::Model, ::MAP, init_vals::AbstractArray, options::Optim.Options=Optim.Options(); kwargs...)
@@ -169,11 +179,11 @@ function Optim.optimize(model::Model, ::MAP, optimizer::Optim.AbstractOptimizer,
169179
return _map_optimize(model, optimizer, options; kwargs...)
170180
end
171181
function Optim.optimize(
172-
model::Model,
173-
::MAP,
174-
init_vals::AbstractArray,
175-
optimizer::Optim.AbstractOptimizer,
176-
options::Optim.Options=Optim.Options();
182+
model::Model,
183+
::MAP,
184+
init_vals::AbstractArray,
185+
optimizer::Optim.AbstractOptimizer,
186+
options::Optim.Options=Optim.Options();
177187
kwargs...
178188
)
179189
return _map_optimize(model, init_vals, optimizer, options; kwargs...)
@@ -190,43 +200,43 @@ end
190200
Estimate a mode, i.e., compute a MLE or MAP estimate.
191201
"""
192202
function _optimize(
193-
model::Model,
194-
f::OptimLogDensity,
195-
optimizer::Optim.AbstractOptimizer = Optim.LBFGS(),
196-
args...;
203+
model::Model,
204+
f::OptimLogDensity,
205+
optimizer::Optim.AbstractOptimizer=Optim.LBFGS(),
206+
args...;
197207
kwargs...
198208
)
199209
return _optimize(model, f, DynamicPPL.getparams(f), optimizer, args...; kwargs...)
200210
end
201211

202212
function _optimize(
203-
model::Model,
204-
f::OptimLogDensity,
205-
options::Optim.Options = Optim.Options(),
206-
args...;
213+
model::Model,
214+
f::OptimLogDensity,
215+
options::Optim.Options=Optim.Options(),
216+
args...;
207217
kwargs...
208218
)
209219
return _optimize(model, f, DynamicPPL.getparams(f), Optim.LBFGS(), args...; kwargs...)
210220
end
211221

212222
function _optimize(
213-
model::Model,
214-
f::OptimLogDensity,
215-
init_vals::AbstractArray = DynamicPPL.getparams(f),
216-
options::Optim.Options = Optim.Options(),
217-
args...;
223+
model::Model,
224+
f::OptimLogDensity,
225+
init_vals::AbstractArray=DynamicPPL.getparams(f),
226+
options::Optim.Options=Optim.Options(),
227+
args...;
218228
kwargs...
219229
)
220230
return _optimize(model, f, init_vals, Optim.LBFGS(), options, args...; kwargs...)
221231
end
222232

223233
function _optimize(
224-
model::Model,
225-
f::OptimLogDensity,
226-
init_vals::AbstractArray = DynamicPPL.getparams(f),
227-
optimizer::Optim.AbstractOptimizer = Optim.LBFGS(),
228-
options::Optim.Options = Optim.Options(),
229-
args...;
234+
model::Model,
235+
f::OptimLogDensity,
236+
init_vals::AbstractArray=DynamicPPL.getparams(f),
237+
optimizer::Optim.AbstractOptimizer=Optim.LBFGS(),
238+
options::Optim.Options=Optim.Options(),
239+
args...;
230240
kwargs...
231241
)
232242
# Convert the initial values, since it is assumed that users provide them
@@ -243,7 +253,7 @@ function _optimize(
243253
@warn "Optimization did not converge! You may need to correct your model or adjust the Optim parameters."
244254
end
245255

246-
# Get the VarInfo at the MLE/MAP point, and run the model to ensure
256+
# Get the VarInfo at the MLE/MAP point, and run the model to ensure
247257
# correct dimensionality.
248258
@set! f.varinfo = DynamicPPL.unflatten(f.varinfo, M.minimizer)
249259
@set! f.varinfo = invlink!!(f.varinfo, model)

0 commit comments

Comments
 (0)