"""
ESS
Elliptical slice sampling algorithm.
# Examples
```jldoctest; setup = :(Random.seed!(1))
julia> @model function gdemo(x)
m ~ Normal()
x ~ Normal(m, 0.5)
end
gdemo (generic function with 2 methods)
julia> sample(gdemo(1.0), ESS(), 1_000) |> mean
Mean
│ Row │ parameters │ mean │
│ │ Symbol │ Float64 │
├─────┼────────────┼──────────┤
│ 1 │ m │ 0.824853 │
```
"""
struct ESS <: InferenceAlgorithm end

# The first step is always accepted: it only checks that every prior is
# Gaussian and returns the initial `VarInfo` unchanged.
function DynamicPPL.initialstep(
    rng::AbstractRNG, model::Model, spl::Sampler{<:ESS}, vi::AbstractVarInfo; kwargs...
)
    for vn in keys(vi)
        dist = getdist(vi, vn)
        EllipticalSliceSampling.isgaussian(typeof(dist)) ||
            error("ESS only supports Gaussian prior distributions")
    end
    return Transition(model, vi), vi
end
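
# For illustration (hypothetical models, not part of this file): a prior such
# as `m ~ Normal()` or `m ~ MvNormal(zeros(2), I)` would pass the check above,
# whereas `m ~ Exponential()` would raise the error at the first step.
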
function AbstractMCMC.step(
    rng::AbstractRNG, model::Model, spl::Sampler{<:ESS}, vi::AbstractVarInfo; kwargs...
)
    # obtain previous sample
    f = vi[:]

    # define previous sampler state
    # (do not use cache to avoid in-place sampling from prior)
    oldstate = EllipticalSliceSampling.ESSState(f, getlogp(vi), nothing)

    # compute next state
    sample, state = AbstractMCMC.step(
        rng,
        EllipticalSliceSampling.ESSModel(
            ESSPrior(model, spl, vi),
            DynamicPPL.LogDensityFunction(
                model, vi, DynamicPPL.SamplingContext(spl, DynamicPPL.DefaultContext())
            ),
        ),
        EllipticalSliceSampling.ESS(),
        oldstate,
    )

    # update sample and log-likelihood
    vi = DynamicPPL.unflatten(vi, sample)
    vi = setlogp!!(vi, state.loglikelihood)

    return Transition(model, vi), vi
end
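
# Each step draws one elliptical slice sample: a proposal is generated on the
# ellipse passing through the current state `f` and a fresh draw from the
# prior, and is accepted based on the log-likelihood alone, so no tuning
# parameters are needed. A minimal sketch of driving the sampler by hand
# (hypothetical usage; assumes `gdemo` from the docstring above):
#
#     rng = Random.default_rng()
#     model = gdemo(1.0)
#     spl = DynamicPPL.Sampler(ESS())
#     transition, state = AbstractMCMC.step(rng, model, spl)
#     for _ in 1:10
#         transition, state = AbstractMCMC.step(rng, model, spl, state)
#     end
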
# Prior distribution of the considered random variables
struct ESSPrior{M<:Model,S<:Sampler{<:ESS},V<:AbstractVarInfo,T}
    model::M
    sampler::S
    varinfo::V
    μ::T

    function ESSPrior{M,S,V}(
        model::M, sampler::S, varinfo::V
    ) where {M<:Model,S<:Sampler{<:ESS},V<:AbstractVarInfo}
        vns = keys(varinfo)
        μ = mapreduce(vcat, vns) do vn
            dist = getdist(varinfo, vn)
            EllipticalSliceSampling.isgaussian(typeof(dist)) ||
                error("ESS only supports Gaussian prior distributions")
            DynamicPPL.tovec(mean(dist))
        end
        return new{M,S,V,typeof(μ)}(model, sampler, varinfo, μ)
    end
end

function ESSPrior(model::Model, sampler::Sampler{<:ESS}, varinfo::AbstractVarInfo)
    return ESSPrior{typeof(model),typeof(sampler),typeof(varinfo)}(model, sampler, varinfo)
end

# Ensure that the prior is a Gaussian distribution (checked in the constructor)
EllipticalSliceSampling.isgaussian(::Type{<:ESSPrior}) = true
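
# EllipticalSliceSampling.jl needs only three operations from this prior:
# `isgaussian` above to validate it, `rand` below to draw the auxiliary point
# that defines the ellipse, and `mean` below to centre the ellipse.
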
# Only define out-of-place sampling
function Base.rand(rng::Random.AbstractRNG, p::ESSPrior)
    sampler = p.sampler
    varinfo = p.varinfo
    # TODO: Surely there's a better way of doing this now that we have `SamplingContext`?
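    # Setting the "del" flag marks every variable for resampling, so that the
    # model call below draws fresh values from the prior instead of reusing
    # the values currently stored in `varinfo`.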
    vns = keys(varinfo)
    for vn in vns
        set_flag!(varinfo, vn, "del")
    end
    p.model(rng, varinfo, sampler)
    return varinfo[:]
end

# Mean of prior distribution
Distributions.mean(p::ESSPrior) = p.μ

# Evaluate log-likelihood of proposals
const ESSLogLikelihood{M<:Model,S<:Sampler{<:ESS},V<:AbstractVarInfo} =
    DynamicPPL.LogDensityFunction{M,V,<:DynamicPPL.SamplingContext{<:S},AD} where {AD}

(ℓ::ESSLogLikelihood)(f::AbstractVector) = LogDensityProblems.logdensity(ℓ, f)
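
# EllipticalSliceSampling.jl evaluates each proposal by calling this
# log-likelihood object on the flattened parameter vector; the context
# overloads below ensure that the returned value contains only likelihood
# terms, not prior terms.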

function DynamicPPL.tilde_assume(
    rng::Random.AbstractRNG, ::DefaultContext, ::Sampler{<:ESS}, right, vn, vi
)
    return DynamicPPL.tilde_assume(
        rng, LikelihoodContext(), SampleFromPrior(), right, vn, vi
    )
end

function DynamicPPL.tilde_observe(ctx::DefaultContext, ::Sampler{<:ESS}, right, left, vi)
    return DynamicPPL.tilde_observe(ctx, SampleFromPrior(), right, left, vi)
end