Skip to content

Commit a1bf714

Browse files
JaimeRZPdevmotion
andauthored
Unify transition also in external samplers (#2030)
* Transition * Revert "Transition" This reverts commit 71c8097. * bug * repeated functions * move Transition to inference * default get_stat * bug * Update src/inference/Inference.jl Co-authored-by: David Widmann <[email protected]> * Update src/inference/Inference.jl Co-authored-by: David Widmann <[email protected]> * Update src/inference/Inference.jl Co-authored-by: David Widmann <[email protected]> * rest of david changes * bring back Transition(a,b) --------- Co-authored-by: David Widmann <[email protected]>
1 parent e41f58c commit a1bf714

File tree

3 files changed

+7
-22
lines changed

3 files changed

+7
-22
lines changed

Project.toml

-1
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,6 @@ Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"
3939
[compat]
4040
AbstractMCMC = "4"
4141
AdvancedHMC = "0.3.0, 0.4"
42-
AdvancedMH = "0.6.8, 0.7"
4342
AdvancedPS = "0.4"
4443
AdvancedVI = "0.2"
4544
BangBang = "0.3"

src/contrib/inference/abstractmcmc.jl

+1-16
Original file line numberDiff line numberDiff line change
@@ -3,28 +3,13 @@ struct TuringState{S,F}
33
logdensity::F
44
end
55

6-
struct TuringTransition{T,NT<:NamedTuple,F<:AbstractFloat}
7-
θ::T
8-
lp::F
9-
stat::NT
10-
end
11-
12-
function TuringTransition(vi::AbstractVarInfo, t)
13-
theta = tonamedtuple(vi)
14-
lp = getlogp(vi)
15-
return TuringTransition(theta, lp, getstats(t))
16-
end
17-
18-
metadata(t::TuringTransition) = merge((lp = t.lp,), t.stat)
19-
DynamicPPL.getlogp(t::TuringTransition) = t.lp
20-
216
state_to_turing(f::DynamicPPL.LogDensityFunction, state) = TuringState(state, f)
227
function transition_to_turing(f::DynamicPPL.LogDensityFunction, transition)
238
θ = getparams(transition)
249
varinfo = DynamicPPL.unflatten(f.varinfo, θ)
2510
# TODO: `deepcopy` is overkill; make more efficient.
2611
varinfo = DynamicPPL.invlink!!(deepcopy(varinfo), f.model)
27-
return TuringTransition(varinfo, transition)
12+
return Transition(varinfo, transition)
2813
end
2914

3015
# NOTE: Only thing that depends on the underlying sampler.

src/inference/Inference.jl

+6-5
Original file line numberDiff line numberDiff line change
@@ -123,6 +123,9 @@ end
123123
######################
124124
# Default Transition #
125125
######################
126+
# Default
127+
# Extended in contrib/inference/abstractmcmc.jl
128+
getstats(t) = nothing
126129

127130
struct Transition{T, F<:AbstractFloat, S<:Union{NamedTuple, Nothing}}
128131
θ :: T
@@ -132,10 +135,10 @@ end
132135

133136
Transition(θ, lp) = Transition(θ, lp, nothing)
134137

135-
function Transition(vi::AbstractVarInfo; nt::NamedTuple=NamedTuple())
138+
function Transition(vi::AbstractVarInfo, t=nothing; nt::NamedTuple=NamedTuple())
136139
θ = merge(tonamedtuple(vi), nt)
137140
lp = getlogp(vi)
138-
return Transition(θ, lp, nothing)
141+
return Transition(θ, lp, getstats(t))
139142
end
140143

141144
function metadata(t::Transition)
@@ -664,9 +667,7 @@ function transitions_from_chain(
664667
model(rng, vi, sampler)
665668

666669
# Convert `VarInfo` into `NamedTuple` and save.
667-
theta = DynamicPPL.tonamedtuple(vi)
668-
lp = Turing.getlogp(vi)
669-
Transition(theta, lp)
670+
Transition(vi)
670671
end
671672

672673
return transitions

0 commit comments

Comments
 (0)