Skip to content

Commit 4a986df

Browse files
Added extension for MCMCChains (#514)
* added extension for MCMCChains * bump patch versoin * added entry to .gitignore * Apply suggestions from code review Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> * forgot to add MCMCChains to extras * Update ext/DynamicPPLMCMCChainsExt.jl Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --------- Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
1 parent 1ebe8bc commit 4a986df

File tree

5 files changed

+41
-1
lines changed

5 files changed

+41
-1
lines changed

Diff for: .gitignore

+1
Original file line numberDiff line numberDiff line change
@@ -3,3 +3,4 @@
33
*.jl.mem
44
.DS_Store
55
Manifest.toml
6+
**.~undo-tree~

Diff for: Project.toml

+11-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "DynamicPPL"
22
uuid = "366bfd00-2699-11ea-058f-f148b4cae6d8"
3-
version = "0.23.11"
3+
version = "0.23.12"
44

55
[deps]
66
AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"
@@ -20,6 +20,15 @@ Setfield = "efcf1570-3423-57d1-acb7-fd33fddbac46"
2020
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
2121
ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444"
2222

23+
[weakdeps]
24+
MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d"
25+
26+
[extensions]
27+
DynamicPPLMCMCChainsExt = ["MCMCChains"]
28+
29+
[extras]
30+
MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d"
31+
2332
[compat]
2433
AbstractMCMC = "2, 3.0, 4"
2534
AbstractPPL = "0.6"
@@ -31,6 +40,7 @@ Distributions = "0.23.8, 0.24, 0.25"
3140
DocStringExtensions = "0.8, 0.9"
3241
LogDensityProblems = "2"
3342
MacroTools = "0.5.6"
43+
MCMCChains = "6"
3444
OrderedCollections = "1"
3545
Setfield = "0.7.1, 0.8, 1"
3646
ZygoteRules = "0.2"

Diff for: ext/DynamicPPLMCMCChainsExt.jl

+16
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
module DynamicPPLMCMCChainsExt
2+
3+
using DynamicPPL: DynamicPPL
4+
using MCMCChains: MCMCChains
5+
6+
function DynamicPPL.generated_quantities(model::DynamicPPL.Model, chain::MCMCChains.Chains)
7+
chain_parameters = MCMCChains.get_sections(chain, :parameters)
8+
varinfo = DynamicPPL.VarInfo(model)
9+
iters = Iterators.product(1:size(chain, 1), 1:size(chain, 3))
10+
return map(iters) do (sample_idx, chain_idx)
11+
DynamicPPL.setval_and_resample!(varinfo, chain_parameters, sample_idx, chain_idx)
12+
model(varinfo)
13+
end
14+
end
15+
16+
end

Diff for: test/ext/DynamicPPLMCMCChainsExt.jl

+9
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
@testset "DynamicPPLMCMCChainsExt" begin
2+
@model demo() = x ~ Normal()
3+
model = demo()
4+
5+
chain = MCMCChains.Chains(randn(1000, 2, 1), [:x, :y], Dict(:internals => [:y]))
6+
chain_generated = @test_nowarn generated_quantities(model, chain)
7+
@test size(chain_generated) == (1000, 1)
8+
@test mean(chain_generated) 0 atol = 0.1
9+
end

Diff for: test/runtests.jl

+4
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,10 @@ include("test_util.jl")
5959
include(joinpath("compat", "ad.jl"))
6060
end
6161

62+
@testset "extensions" begin
63+
include("ext/DynamicPPLMCMCChainsExt.jl")
64+
end
65+
6266
@testset "doctests" begin
6367
DocMeta.setdocmeta!(
6468
DynamicPPL,

0 commit comments

Comments
 (0)