|
32 | 32 | end
|
33 | 33 | @testset "Initial parameters" begin
|
34 | 34 | # dummy algorithm that just returns initial value and does not perform any sampling
|
35 |
| - struct OnlyInitAlg end |
| 35 | + abstract type OnlyInitAlg end |
| 36 | + struct OnlyInitAlgDefault <: OnlyInitAlg end |
| 37 | + struct OnlyInitAlgUniform <: OnlyInitAlg end |
36 | 38 | function DynamicPPL.initialstep(
|
37 | 39 | rng::Random.AbstractRNG,
|
38 | 40 | model::Model,
|
39 |
| - ::Sampler{OnlyInitAlg}, |
| 41 | + ::Sampler{<:OnlyInitAlg}, |
40 | 42 | vi::AbstractVarInfo;
|
41 | 43 | kwargs...,
|
42 | 44 | )
|
43 | 45 | return vi, nothing
|
44 | 46 | end
|
45 |
| - DynamicPPL.getspace(::Sampler{OnlyInitAlg}) = () |
| 47 | + DynamicPPL.getspace(::Sampler{<:OnlyInitAlg}) = () |
46 | 48 |
|
47 |
| - # model with one variable: initialization p = 0.2 |
48 |
| - @model function coinflip() |
49 |
| - p ~ Beta(1, 1) |
50 |
| - 10 ~ Binomial(25, p) |
51 |
| - end |
52 |
| - model = coinflip() |
53 |
| - sampler = Sampler(OnlyInitAlg()) |
54 |
| - lptrue = logpdf(Binomial(25, 0.2), 10) |
55 |
| - chain = sample(model, sampler, 1; init_params = 0.2, progress = false) |
56 |
| - @test chain[1].metadata.p.vals == [0.2] |
57 |
| - @test getlogp(chain[1]) == lptrue |
| 49 | + # initial samplers |
| 50 | + DynamicPPL.initialsampler(::Sampler{OnlyInitAlgUniform}) = SampleFromUniform() |
| 51 | + @test DynamicPPL.initialsampler(Sampler(OnlyInitAlgDefault())) == SampleFromPrior() |
58 | 52 |
|
59 |
| - # parallel sampling |
60 |
| - chains = sample( |
61 |
| - model, sampler, MCMCThreads(), 1, 10; |
62 |
| - init_params = 0.2, progress = false, |
63 |
| - ) |
64 |
| - for c in chains |
65 |
| - @test c[1].metadata.p.vals == [0.2] |
66 |
| - @test getlogp(c[1]) == lptrue |
67 |
| - end |
| 53 | + for alg in (OnlyInitAlgDefault(), OnlyInitAlgUniform()) |
| 54 | + # model with one variable: initialization p = 0.2 |
| 55 | + @model function coinflip() |
| 56 | + p ~ Beta(1, 1) |
| 57 | + 10 ~ Binomial(25, p) |
| 58 | + end |
| 59 | + model = coinflip() |
| 60 | + sampler = Sampler(alg) |
| 61 | + lptrue = logpdf(Binomial(25, 0.2), 10) |
| 62 | + chain = sample(model, sampler, 1; init_params = 0.2, progress = false) |
| 63 | + @test chain[1].metadata.p.vals == [0.2] |
| 64 | + @test getlogp(chain[1]) == lptrue |
68 | 65 |
|
69 |
| - # model with two variables: initialization s = 4, m = -1 |
70 |
| - @model function twovars() |
71 |
| - s ~ InverseGamma(2, 3) |
72 |
| - m ~ Normal(0, sqrt(s)) |
73 |
| - end |
74 |
| - model = twovars() |
75 |
| - lptrue = logpdf(InverseGamma(2, 3), 4) + logpdf(Normal(0, 2), -1) |
76 |
| - chain = sample(model, sampler, 1; init_params = [4, -1], progress = false) |
77 |
| - @test chain[1].metadata.s.vals == [4] |
78 |
| - @test chain[1].metadata.m.vals == [-1] |
79 |
| - @test getlogp(chain[1]) == lptrue |
| 66 | + # parallel sampling |
| 67 | + chains = sample( |
| 68 | + model, sampler, MCMCThreads(), 1, 10; |
| 69 | + init_params = 0.2, progress = false, |
| 70 | + ) |
| 71 | + for c in chains |
| 72 | + @test c[1].metadata.p.vals == [0.2] |
| 73 | + @test getlogp(c[1]) == lptrue |
| 74 | + end |
80 | 75 |
|
81 |
| - # parallel sampling |
82 |
| - chains = sample( |
83 |
| - model, sampler, MCMCThreads(), 1, 10; |
84 |
| - init_params = [4, -1], progress = false, |
85 |
| - ) |
86 |
| - for c in chains |
87 |
| - @test c[1].metadata.s.vals == [4] |
88 |
| - @test c[1].metadata.m.vals == [-1] |
89 |
| - @test getlogp(c[1]) == lptrue |
90 |
| - end |
| 76 | + # model with two variables: initialization s = 4, m = -1 |
| 77 | + @model function twovars() |
| 78 | + s ~ InverseGamma(2, 3) |
| 79 | + m ~ Normal(0, sqrt(s)) |
| 80 | + end |
| 81 | + model = twovars() |
| 82 | + lptrue = logpdf(InverseGamma(2, 3), 4) + logpdf(Normal(0, 2), -1) |
| 83 | + chain = sample(model, sampler, 1; init_params = [4, -1], progress = false) |
| 84 | + @test chain[1].metadata.s.vals == [4] |
| 85 | + @test chain[1].metadata.m.vals == [-1] |
| 86 | + @test getlogp(chain[1]) == lptrue |
| 87 | + |
| 88 | + # parallel sampling |
| 89 | + chains = sample( |
| 90 | + model, sampler, MCMCThreads(), 1, 10; |
| 91 | + init_params = [4, -1], progress = false, |
| 92 | + ) |
| 93 | + for c in chains |
| 94 | + @test c[1].metadata.s.vals == [4] |
| 95 | + @test c[1].metadata.m.vals == [-1] |
| 96 | + @test getlogp(c[1]) == lptrue |
| 97 | + end |
91 | 98 |
|
92 |
| - # set only m = -1 |
93 |
| - chain = sample(model, sampler, 1; init_params = [missing, -1], progress = false) |
94 |
| - @test !ismissing(chain[1].metadata.s.vals[1]) |
95 |
| - @test chain[1].metadata.m.vals == [-1] |
| 99 | + # set only m = -1 |
| 100 | + chain = sample(model, sampler, 1; init_params = [missing, -1], progress = false) |
| 101 | + @test !ismissing(chain[1].metadata.s.vals[1]) |
| 102 | + @test chain[1].metadata.m.vals == [-1] |
96 | 103 |
|
97 |
| - # parallel sampling |
98 |
| - chains = sample( |
99 |
| - model, sampler, MCMCThreads(), 1, 10; |
100 |
| - init_params = [missing, -1], progress = false, |
101 |
| - ) |
102 |
| - for c in chains |
103 |
| - @test !ismissing(c[1].metadata.s.vals[1]) |
104 |
| - @test c[1].metadata.m.vals == [-1] |
| 104 | + # parallel sampling |
| 105 | + chains = sample( |
| 106 | + model, sampler, MCMCThreads(), 1, 10; |
| 107 | + init_params = [missing, -1], progress = false, |
| 108 | + ) |
| 109 | + for c in chains |
| 110 | + @test !ismissing(c[1].metadata.s.vals[1]) |
| 111 | + @test c[1].metadata.m.vals == [-1] |
| 112 | + end |
105 | 113 | end
|
106 | 114 | end
|
107 | 115 | end
|
0 commit comments