Skip to content

Commit 7c8edab

Browse files
committed
Fix initialization of parameters for algorithms that use SampleFromUniform (#232)
This PR is a quick fix for TuringLang/Turing.jl#1563 and TuringLang/Turing.jl#1588. As explained in TuringLang/Turing.jl#1588 (comment), the problem is that currently `SampleFromUniform` always resamples variables in every run, and hence also initial parameters provided by users are resampled in https://github.com/TuringLang/DynamicPPL.jl/blob/9d4137eb33e83f34c484bf78f9a57f828b3c92a0/src/sampler.jl#L80. As mentioned in TuringLang/Turing.jl#1588 (comment), a better long term solution would be to fix this inconsistency and use dedicated evaluation and sampling contexts, as suggested in #80.
1 parent 9d4137e commit 7c8edab

File tree

3 files changed

+74
-58
lines changed

3 files changed

+74
-58
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "DynamicPPL"
22
uuid = "366bfd00-2699-11ea-058f-f148b4cae6d8"
3-
version = "0.10.15"
3+
version = "0.10.16"
44

55
[deps]
66
AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"

src/sampler.jl

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,15 @@ function AbstractMCMC.step(
7777
initialize_parameters!(vi, kwargs[:init_params], spl)
7878

7979
# Update joint log probability.
80-
model(rng, vi, _spl)
80+
# TODO: fix properly by using sampler and evaluation contexts
81+
# This is a quick fix for https://github.com/TuringLang/Turing.jl/issues/1588
82+
# and https://github.com/TuringLang/Turing.jl/issues/1563
83+
# to avoid that existing variables are resampled
84+
if _spl isa SampleFromUniform
85+
model(rng, vi, SampleFromPrior())
86+
else
87+
model(rng, vi, _spl)
88+
end
8189
end
8290

8391
return initialstep(rng, model, spl, vi; kwargs...)

test/sampler.jl

Lines changed: 64 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -32,76 +32,84 @@
3232
end
3333
@testset "Initial parameters" begin
3434
# dummy algorithm that just returns initial value and does not perform any sampling
35-
struct OnlyInitAlg end
35+
abstract type OnlyInitAlg end
36+
struct OnlyInitAlgDefault <: OnlyInitAlg end
37+
struct OnlyInitAlgUniform <: OnlyInitAlg end
3638
function DynamicPPL.initialstep(
3739
rng::Random.AbstractRNG,
3840
model::Model,
39-
::Sampler{OnlyInitAlg},
41+
::Sampler{<:OnlyInitAlg},
4042
vi::AbstractVarInfo;
4143
kwargs...,
4244
)
4345
return vi, nothing
4446
end
45-
DynamicPPL.getspace(::Sampler{OnlyInitAlg}) = ()
47+
DynamicPPL.getspace(::Sampler{<:OnlyInitAlg}) = ()
4648

47-
# model with one variable: initialization p = 0.2
48-
@model function coinflip()
49-
p ~ Beta(1, 1)
50-
10 ~ Binomial(25, p)
51-
end
52-
model = coinflip()
53-
sampler = Sampler(OnlyInitAlg())
54-
lptrue = logpdf(Binomial(25, 0.2), 10)
55-
chain = sample(model, sampler, 1; init_params = 0.2, progress = false)
56-
@test chain[1].metadata.p.vals == [0.2]
57-
@test getlogp(chain[1]) == lptrue
49+
# initial samplers
50+
DynamicPPL.initialsampler(::Sampler{OnlyInitAlgUniform}) = SampleFromUniform()
51+
@test DynamicPPL.initialsampler(Sampler(OnlyInitAlgDefault())) == SampleFromPrior()
5852

59-
# parallel sampling
60-
chains = sample(
61-
model, sampler, MCMCThreads(), 1, 10;
62-
init_params = 0.2, progress = false,
63-
)
64-
for c in chains
65-
@test c[1].metadata.p.vals == [0.2]
66-
@test getlogp(c[1]) == lptrue
67-
end
53+
for alg in (OnlyInitAlgDefault(), OnlyInitAlgUniform())
54+
# model with one variable: initialization p = 0.2
55+
@model function coinflip()
56+
p ~ Beta(1, 1)
57+
10 ~ Binomial(25, p)
58+
end
59+
model = coinflip()
60+
sampler = Sampler(alg)
61+
lptrue = logpdf(Binomial(25, 0.2), 10)
62+
chain = sample(model, sampler, 1; init_params = 0.2, progress = false)
63+
@test chain[1].metadata.p.vals == [0.2]
64+
@test getlogp(chain[1]) == lptrue
6865

69-
# model with two variables: initialization s = 4, m = -1
70-
@model function twovars()
71-
s ~ InverseGamma(2, 3)
72-
m ~ Normal(0, sqrt(s))
73-
end
74-
model = twovars()
75-
lptrue = logpdf(InverseGamma(2, 3), 4) + logpdf(Normal(0, 2), -1)
76-
chain = sample(model, sampler, 1; init_params = [4, -1], progress = false)
77-
@test chain[1].metadata.s.vals == [4]
78-
@test chain[1].metadata.m.vals == [-1]
79-
@test getlogp(chain[1]) == lptrue
66+
# parallel sampling
67+
chains = sample(
68+
model, sampler, MCMCThreads(), 1, 10;
69+
init_params = 0.2, progress = false,
70+
)
71+
for c in chains
72+
@test c[1].metadata.p.vals == [0.2]
73+
@test getlogp(c[1]) == lptrue
74+
end
8075

81-
# parallel sampling
82-
chains = sample(
83-
model, sampler, MCMCThreads(), 1, 10;
84-
init_params = [4, -1], progress = false,
85-
)
86-
for c in chains
87-
@test c[1].metadata.s.vals == [4]
88-
@test c[1].metadata.m.vals == [-1]
89-
@test getlogp(c[1]) == lptrue
90-
end
76+
# model with two variables: initialization s = 4, m = -1
77+
@model function twovars()
78+
s ~ InverseGamma(2, 3)
79+
m ~ Normal(0, sqrt(s))
80+
end
81+
model = twovars()
82+
lptrue = logpdf(InverseGamma(2, 3), 4) + logpdf(Normal(0, 2), -1)
83+
chain = sample(model, sampler, 1; init_params = [4, -1], progress = false)
84+
@test chain[1].metadata.s.vals == [4]
85+
@test chain[1].metadata.m.vals == [-1]
86+
@test getlogp(chain[1]) == lptrue
87+
88+
# parallel sampling
89+
chains = sample(
90+
model, sampler, MCMCThreads(), 1, 10;
91+
init_params = [4, -1], progress = false,
92+
)
93+
for c in chains
94+
@test c[1].metadata.s.vals == [4]
95+
@test c[1].metadata.m.vals == [-1]
96+
@test getlogp(c[1]) == lptrue
97+
end
9198

92-
# set only m = -1
93-
chain = sample(model, sampler, 1; init_params = [missing, -1], progress = false)
94-
@test !ismissing(chain[1].metadata.s.vals[1])
95-
@test chain[1].metadata.m.vals == [-1]
99+
# set only m = -1
100+
chain = sample(model, sampler, 1; init_params = [missing, -1], progress = false)
101+
@test !ismissing(chain[1].metadata.s.vals[1])
102+
@test chain[1].metadata.m.vals == [-1]
96103

97-
# parallel sampling
98-
chains = sample(
99-
model, sampler, MCMCThreads(), 1, 10;
100-
init_params = [missing, -1], progress = false,
101-
)
102-
for c in chains
103-
@test !ismissing(c[1].metadata.s.vals[1])
104-
@test c[1].metadata.m.vals == [-1]
104+
# parallel sampling
105+
chains = sample(
106+
model, sampler, MCMCThreads(), 1, 10;
107+
init_params = [missing, -1], progress = false,
108+
)
109+
for c in chains
110+
@test !ismissing(c[1].metadata.s.vals[1])
111+
@test c[1].metadata.m.vals == [-1]
112+
end
105113
end
106114
end
107115
end

0 commit comments

Comments
 (0)