Skip to content

Commit 43e2f20

Browse files
authored
Simplify Turing.Variational and fix test error (#1377)
1 parent 3931d47 commit 43e2f20

File tree

7 files changed

+68
-177
lines changed

7 files changed

+68
-177
lines changed

Project.toml

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,6 @@ LogDensityProblems = "6fdf6af0-433a-55f7-b3ed-c6c6e0b8df7c"
2020
MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d"
2121
NamedArrays = "86f7a689-2022-50b4-a561-43c23ac3c673"
2222
Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"
23-
ProgressLogging = "33c8b6b6-d38a-422a-b730-caa89a2f386c"
2423
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
2524
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
2625
Requires = "ae029012-a4dd-5104-9daa-d747884805df"
@@ -46,7 +45,6 @@ Libtask = "0.4"
4645
LogDensityProblems = "^0.9, 0.10"
4746
MCMCChains = "4"
4847
NamedArrays = "0.9"
49-
ProgressLogging = "0.1"
5048
Reexport = "0.2.0"
5149
Requires = "0.5, 1.0"
5250
SpecialFunctions = "0.7.2, 0.8, 0.9, 0.10"

src/Turing.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,12 +15,14 @@ using Libtask
1515
@reexport using Distributions, MCMCChains, Libtask, AbstractMCMC, Bijectors
1616
using Tracker: Tracker
1717

18+
import AdvancedVI
1819
import DynamicPPL: getspace, NoDist, NamedDist
1920

2021
const PROGRESS = Ref(true)
2122
function turnprogress(switch::Bool)
2223
@info "[Turing]: progress logging is $(switch ? "enabled" : "disabled") globally"
2324
PROGRESS[] = switch
25+
AdvancedVI.turnprogress(switch)
2426
end
2527

2628
# Random probability measures.

src/variational/VariationalInference.jl

Lines changed: 19 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -1,23 +1,16 @@
11
module Variational
22

3-
using ..Core, ..Utilities
4-
using DocStringExtensions: TYPEDEF, TYPEDFIELDS
5-
using Distributions, Bijectors, DynamicPPL
6-
using LinearAlgebra
7-
using ..Turing: PROGRESS, Turing
8-
using DynamicPPL: Model, SampleFromPrior, SampleFromUniform
9-
using Random: AbstractRNG
3+
import AdvancedVI
4+
import Bijectors
5+
import DistributionsAD
6+
import DynamicPPL
7+
import StatsBase
8+
import StatsFuns
109

11-
using ForwardDiff
12-
using Tracker
13-
14-
import ..Core: getchunksize, getADbackend
15-
16-
import AbstractMCMC
17-
import ProgressLogging
18-
19-
using AdvancedVI
10+
import Random
2011

12+
# Reexports
13+
using AdvancedVI: vi, ADVI, ELBO, elbo, TruncatedADAGrad, DecayedADAGrad
2114
export
2215
vi,
2316
ADVI,
@@ -34,38 +27,38 @@ use `DynamicPPL.MiniBatch` context to run the `Model` with a weight `num_total_o
3427
## Notes
3528
- For sake of efficiency, the returned function is closes over an instance of `VarInfo`. This means that you *might* run into some weird behaviour if you call this method sequentially using different types; if that's the case, just generate a new one for each type using `make_logjoint`.
3629
"""
37-
function make_logjoint(model::Model; weight = 1.0)
30+
function make_logjoint(model::DynamicPPL.Model; weight = 1.0)
3831
# setup
3932
ctx = DynamicPPL.MiniBatchContext(
4033
DynamicPPL.DefaultContext(),
4134
weight
4235
)
43-
varinfo_init = Turing.VarInfo(model, ctx)
36+
varinfo_init = DynamicPPL.VarInfo(model, ctx)
4437

4538
function logπ(z)
46-
varinfo = VarInfo(varinfo_init, SampleFromUniform(), z)
39+
varinfo = DynamicPPL.VarInfo(varinfo_init, DynamicPPL.SampleFromUniform(), z)
4740
model(varinfo)
4841

49-
return getlogp(varinfo)
42+
return DynamicPPL.getlogp(varinfo)
5043
end
5144

5245
return logπ
5346
end
5447

55-
function logjoint(model::Model, varinfo, z)
56-
varinfo = VarInfo(varinfo, SampleFromUniform(), z)
48+
function logjoint(model::DynamicPPL.Model, varinfo, z)
49+
varinfo = DynamicPPL.VarInfo(varinfo, DynamicPPL.SampleFromUniform(), z)
5750
model(varinfo)
5851

59-
return getlogp(varinfo)
52+
return DynamicPPL.getlogp(varinfo)
6053
end
6154

6255

6356
# objectives
6457
function (elbo::ELBO)(
65-
rng::AbstractRNG,
66-
alg::VariationalInference,
58+
rng::Random.AbstractRNG,
59+
alg::AdvancedVI.VariationalInference,
6760
q,
68-
model::Model,
61+
model::DynamicPPL.Model,
6962
num_samples;
7063
weight = 1.0,
7164
kwargs...

src/variational/advi.jl

Lines changed: 46 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,14 @@
1-
using StatsFuns
2-
using DistributionsAD
3-
using Bijectors
4-
using Bijectors: TransformedDistribution
5-
using Random: AbstractRNG, GLOBAL_RNG
6-
import Bijectors: bijector
7-
81
"""
9-
bijector(model::Model; sym_to_ranges = Val(false))
2+
bijector(model::Model[, sym2ranges = Val(false)])
103
114
Returns a `Stacked <: Bijector` which maps from the support of the posterior to ℝᵈ with `d`
125
denoting the dimensionality of the latent variables.
136
"""
14-
function bijector(model::Model; sym_to_ranges::Val{sym2ranges} = Val(false)) where {sym2ranges}
15-
varinfo = Turing.VarInfo(model)
7+
function Bijectors.bijector(
8+
model::DynamicPPL.Model,
9+
::Val{sym2ranges} = Val(false),
10+
) where {sym2ranges}
11+
varinfo = DynamicPPL.VarInfo(model)
1612
num_params = sum([size(varinfo.metadata[sym].vals, 1)
1713
for sym keys(varinfo.metadata)])
1814

@@ -37,25 +33,27 @@ function bijector(model::Model; sym_to_ranges::Val{sym2ranges} = Val(false)) whe
3733
idx += varinfo.metadata[sym].ranges[end][end]
3834
end
3935

40-
bs = bijector.(tuple(dists...))
36+
bs = Bijectors.bijector.(tuple(dists...))
4137

4238
if sym2ranges
43-
return Stacked(bs, ranges), (; collect(zip(keys(sym_lookup), values(sym_lookup)))...)
39+
return (
40+
Bijectors.Stacked(bs, ranges),
41+
(; collect(zip(keys(sym_lookup), values(sym_lookup)))...),
42+
)
4443
else
45-
return Stacked(bs, ranges)
44+
return Bijectors.Stacked(bs, ranges)
4645
end
4746
end
4847

4948
"""
50-
meanfield(model::Model)
51-
meanfield(rng::AbstractRNG, model::Model)
49+
meanfield([rng, ]model::Model)
5250
5351
Creates a mean-field approximation with multivariate normal as underlying distribution.
5452
"""
55-
meanfield(model::Model) = meanfield(GLOBAL_RNG, model)
56-
function meanfield(rng::AbstractRNG, model::Model)
53+
meanfield(model::DynamicPPL.Model) = meanfield(Random.GLOBAL_RNG, model)
54+
function meanfield(rng::Random.AbstractRNG, model::DynamicPPL.Model)
5755
# setup
58-
varinfo = Turing.VarInfo(model)
56+
varinfo = DynamicPPL.VarInfo(model)
5957
num_params = sum([size(varinfo.metadata[sym].vals, 1)
6058
for sym keys(varinfo.metadata)])
6159

@@ -71,43 +69,58 @@ function meanfield(rng::AbstractRNG, model::Model)
7169
ranges[range_idx] = idx .+ r
7270
range_idx += 1
7371
end
74-
72+
7573
# append!(ranges, [idx .+ r for r ∈ varinfo.metadata[sym].ranges])
7674
idx += varinfo.metadata[sym].ranges[end][end]
7775
end
7876

7977
# initial params
8078
μ = randn(rng, num_params)
81-
σ = softplus.(randn(rng, num_params))
79+
σ = StatsFuns.softplus.(randn(rng, num_params))
8280

8381
# construct variational posterior
84-
d = TuringDiagMvNormal(μ, σ)
85-
bs = inv.(bijector.(tuple(dists...)))
86-
b = Stacked(bs, ranges)
82+
d = DistributionsAD.TuringDiagMvNormal(μ, σ)
83+
bs = inv.(Bijectors.bijector.(tuple(dists...)))
84+
b = Bijectors.Stacked(bs, ranges)
8785

88-
return transformed(d, b)
86+
return Bijectors.transformed(d, b)
8987
end
9088

91-
9289
# Overloading stuff from `AdvancedVI` to specialize for Turing
93-
AdvancedVI.update(d::TuringDiagMvNormal, μ, σ) = TuringDiagMvNormal(μ, σ)
94-
AdvancedVI.update(td::TransformedDistribution, θ...) = transformed(AdvancedVI.update(td.dist, θ...), td.transform)
95-
function AdvancedVI.update(td::TransformedDistribution{<:TuringDiagMvNormal}, θ::AbstractArray)
90+
function AdvancedVI.update(d::DistributionsAD.TuringDiagMvNormal, μ, σ)
91+
return DistributionsAD.TuringDiagMvNormal(μ, σ)
92+
end
93+
function AdvancedVI.update(td::Bijectors.TransformedDistribution, θ...)
94+
return Bijectors.transformed(AdvancedVI.update(td.dist, θ...), td.transform)
95+
end
96+
function AdvancedVI.update(
97+
td::Bijectors.TransformedDistribution{<:DistributionsAD.TuringDiagMvNormal},
98+
θ::AbstractArray,
99+
)
96100
μ, ω = θ[1:length(td)], θ[length(td) + 1:end]
97-
return AdvancedVI.update(td, μ, softplus.(ω))
101+
return AdvancedVI.update(td, μ, StatsFuns.softplus.(ω))
98102
end
99103

100-
function AdvancedVI.vi(model::Model, alg::ADVI; optimizer = TruncatedADAGrad())
104+
function AdvancedVI.vi(
105+
model::DynamicPPL.Model,
106+
alg::AdvancedVI.ADVI;
107+
optimizer = AdvancedVI.TruncatedADAGrad(),
108+
)
101109
q = meanfield(model)
102110
return AdvancedVI.vi(model, alg, q; optimizer = optimizer)
103111
end
104112

105113

106-
function AdvancedVI.vi(model::Model, alg::ADVI, q::TransformedDistribution{<:TuringDiagMvNormal}; optimizer = TruncatedADAGrad())
114+
function AdvancedVI.vi(
115+
model::DynamicPPL.Model,
116+
alg::AdvancedVI.ADVI,
117+
q::Bijectors.TransformedDistribution{<:DistributionsAD.TuringDiagMvNormal};
118+
optimizer = AdvancedVI.TruncatedADAGrad(),
119+
)
107120
@debug "Optimizing ADVI..."
108121
# Initial parameters for mean-field approx
109-
μ, σs = params(q)
110-
θ = vcat(μ, invsoftplus.(σs))
122+
μ, σs = StatsBase.params(q)
123+
θ = vcat(μ, StatsFuns.invsoftplus.(σs))
111124

112125
# Optimize
113126
AdvancedVI.optimize!(elbo, alg, q, make_logjoint(model), θ; optimizer = optimizer)

src/variational/objectives.jl

Lines changed: 0 additions & 21 deletions
This file was deleted.

src/variational/optimisers.jl

Lines changed: 0 additions & 94 deletions
This file was deleted.

test/variational/optimisers.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
using Random, Test, LinearAlgebra, ForwardDiff
2-
using Turing.Variational: TruncatedADAGrad, DecayedADAGrad, apply!
2+
using AdvancedVI: TruncatedADAGrad, DecayedADAGrad, apply!
33

44
function test_opt(ADPack, opt)
55
θ = randn(10, 10)

0 commit comments

Comments
 (0)