-
Notifications
You must be signed in to change notification settings - Fork 35
[Merged by Bors] - Linearization/flattening of SimpleVarInfo #417
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 173 commits
14594d6
0bc279f
81ee12e
d39f87d
23f34cc
d3ec108
81782c9
12bfb42
c2c2417
2e2cb5c
f6c3fc4
d643a78
3cab7d9
8b870dc
0b304db
b146b11
d170d92
f46183b
7a78eec
70b3b70
a03e8cf
793c931
a9b12fd
b1d7f9a
be98961
3e1588b
7697fce
3139c62
3610658
27171ad
cd2d9d6
d948cb9
6a3e18f
d674478
26d2dbb
3fcba56
83a9448
356fa9c
13f037f
c7544e0
d1dccf1
ff7ff4a
2f1a2ff
d6311b7
ed2fa69
a82be56
7aacee5
96f128f
2e88d08
0f9765b
116c95c
d797e99
44b2f66
81cd881
2e14abd
12adc83
af3e6ba
f7501df
abcabf4
fdee509
e974c83
6c6d5f5
5d5bc88
0498336
f86f264
0d31137
c52630b
fff060c
e21958c
9669345
7e02735
a412029
f3818c3
8b799a4
93cb298
ba5852b
328f713
801bd4c
46f6f4c
bcb767b
c5be1c2
f266929
c2dbbaf
1abb46c
1086c6c
f2fb4a5
490d24e
6350ccd
951e4c3
dcd92c9
2922ffa
5266a4b
3c38710
ba92f3f
70c864c
8b6b440
5843699
66f41a9
eb2d6b5
e8cdb91
359d384
0b20f09
ab0a99b
dd10913
18d28cc
32b7aab
fb86231
f782fe2
7d3493d
2b1893c
912d7f8
a276e4a
626eea2
1558924
a62c881
5cc195a
ea5a7a4
702f2ff
d8f4970
56f30bc
2eaef02
f0f981b
8063d1e
1e0b946
faa0e42
2935bde
78f22e1
025a4d4
9e7f493
a72e9b8
0a9383b
c057080
f5c60ae
d8b0a75
7149c02
25f05de
e05fa29
431664d
363ebae
66424f8
3bc27f8
ca5b080
cb05fc9
45445cf
ea8f844
52274ba
7da0ee9
b3499a3
2bd5dcd
61a594c
ce5f6e4
6c941bd
5e92e56
aabc45a
939540c
aecf97f
9dcefdb
fd0796b
2ef1f59
94e5d48
15fdf19
48dfb9c
70ba82d
d3bff26
b14e9cf
9f106fa
bf34356
3e5f763
e649f37
5941270
f30b875
cb3e1f4
f79fab4
58c2550
c316e70
998fcf4
0913a24
9af2638
482ade7
b79bf28
15087c5
88dbdca
9b3c40f
57d321c
4409149
c34f257
73765e7
0c0c393
5e51755
3dbc7a9
809de9a
656175f
a225978
fce67ee
b165b35
af9c520
47c30e3
40477a4
a972f8e
ab2a8b5
ee7fcd6
a44e712
3e869ff
08de024
0082505
597dfda
be7ae6c
8dfc7c0
e475c87
1e6b0a9
ce5757a
e23763b
43b034d
5c7df84
5c7163d
5f21dd7
e044676
effec2b
c7240ab
5fabd07
2e0fe49
965fcf5
200a886
5a0296e
a2e332e
d950635
e0907c1
6aceac4
e5d8984
63b3638
3ffeef1
8ed91ec
8ddfb4c
3246cf4
74b5d93
c6264e5
da04c7b
41fd89a
e7b8b10
732e94b
1c1b6ed
fbf9e0a
a427983
6b126b8
d715b0c
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||||
---|---|---|---|---|---|---|
@@ -1,8 +1,3 @@ | ||||||
abstract type AbstractTransformation end | ||||||
|
||||||
struct NoTransformation <: AbstractTransformation end | ||||||
struct DefaultTransformation <: AbstractTransformation end | ||||||
|
||||||
""" | ||||||
$(TYPEDEF) | ||||||
|
||||||
|
@@ -197,6 +192,8 @@ struct SimpleVarInfo{NT,T,C<:AbstractTransformation} <: AbstractVarInfo | |||||
transformation::C | ||||||
end | ||||||
|
||||||
transformation(vi::SimpleVarInfo) = vi.transformation | ||||||
|
||||||
SimpleVarInfo(values, logp) = SimpleVarInfo(values, logp, NoTransformation()) | ||||||
|
||||||
function SimpleVarInfo{T}(θ) where {T<:Real} | ||||||
|
@@ -227,9 +224,17 @@ function SimpleVarInfo{T}( | |||||
return SimpleVarInfo(values, convert(T, getlogp(vi))) | ||||||
end | ||||||
|
||||||
SimpleVarInfo(svi::SimpleVarInfo, spl, x::AbstractVector) = unflatten(svi, x) | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Is this needed? Seems a bit like introducing some of the surprising There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Ah no, good catch! |
||||||
|
||||||
unflatten(svi::SimpleVarInfo, spl, x::AbstractVector) = unflatten(svi, x) | ||||||
yebai marked this conversation as resolved.
Show resolved
Hide resolved
|
||||||
function unflatten(svi::SimpleVarInfo, x::AbstractVector) | ||||||
return Setfield.@set svi.values = unflatten(svi.values, x) | ||||||
end | ||||||
|
||||||
function BangBang.empty!!(vi::SimpleVarInfo) | ||||||
Setfield.@set resetlogp!!(vi).values = empty!!(vi.values) | ||||||
return resetlogp!!(Setfield.@set vi.values = empty!!(vi.values)) | ||||||
yebai marked this conversation as resolved.
Show resolved
Hide resolved
|
||||||
end | ||||||
Base.isempty(vi::SimpleVarInfo) = isempty(vi.values) | ||||||
|
||||||
getlogp(vi::SimpleVarInfo) = vi.logp | ||||||
setlogp!!(vi::SimpleVarInfo, logp) = Setfield.@set vi.logp = logp | ||||||
|
@@ -308,11 +313,8 @@ end | |||||
# HACK: Needed to disambiguiate. | ||||||
Base.getindex(vi::SimpleVarInfo, vns::Vector{<:VarName}) = map(Base.Fix1(getindex, vi), vns) | ||||||
|
||||||
Base.getindex(vi::SimpleVarInfo, spl::SampleFromPrior) = vi.values | ||||||
Base.getindex(vi::SimpleVarInfo, spl::SampleFromUniform) = vi.values | ||||||
|
||||||
# TODO: Should we do better? | ||||||
Base.getindex(vi::SimpleVarInfo, spl::Sampler) = vi.values | ||||||
Base.getindex(svi::SimpleVarInfo, ::Colon) = values_as(svi, Vector) | ||||||
Base.getindex(svi::SimpleVarInfo, ::Sampler) = svi[:] | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Should this be
Suggested change
in line with There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This will cause method ambiguity; there are definitions for There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Could one generalize the definitons in src/varinfo.jl as well to fix those ambiguities? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We could just remove the two impls for |
||||||
|
||||||
# Since we don't perform any transformations in `getindex` for `SimpleVarInfo` | ||||||
# we simply call `getindex` in `getindex_raw`. | ||||||
|
@@ -365,6 +367,10 @@ function BangBang.setindex!!(vi::SimpleVarInfo, val, vn::VarName) | |||||
return Setfield.@set vi.values = set!!(vi.values, vn, val) | ||||||
end | ||||||
|
||||||
function BangBang.setindex!!(vi::SimpleVarInfo, val, spl::AbstractSampler) | ||||||
return unflatten(vi, spl, val) | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Maybe consider depreciating the API or remove it? The |
||||||
end | ||||||
|
||||||
# TODO: Specialize to handle certain cases, e.g. a collection of `VarName` with | ||||||
# same symbol and same type of, say, `IndexLens`, for improved `.~` performance. | ||||||
function BangBang.setindex!!(vi::SimpleVarInfo, vals, vns::AbstractVector{<:VarName}) | ||||||
|
@@ -509,6 +515,45 @@ end | |||||
|
||||||
# HACK: Allows us to re-use the implementation of `dot_tilde`, etc. for literals. | ||||||
increment_num_produce!(::SimpleOrThreadSafeSimple) = nothing | ||||||
setgid!(vi::SimpleOrThreadSafeSimple, gid::Selector, vn::VarName) = nothing | ||||||
devmotion marked this conversation as resolved.
Show resolved
Hide resolved
|
||||||
|
||||||
# We need these to be compatible with how chains are constructed from `AbstractVarInfo` in Turing.jl. | ||||||
# TODO: Move away from using these `tonamedtuple` methods. | ||||||
function tonamedtuple(vi::SimpleOrThreadSafeSimple{<:NamedTuple{names}}) where {names} | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Could we unify the approaches for getting named tuples, vectors etc? E.g. by using There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. So IIRC it also causes insane performance issues for larger models when constructing the chains. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Would be great to open an issue for this. |
||||||
nt_vals = map(keys(vi)) do vn | ||||||
val = vi[vn] | ||||||
vns = collect(DynamicPPL.TestUtils.varname_leaves(vn, val)) | ||||||
vals = map(Base.Fix1(getindex, vi), vns) | ||||||
(vals, map(string, vns)) | ||||||
end | ||||||
|
||||||
return NamedTuple{names}(nt_vals) | ||||||
end | ||||||
|
||||||
function tonamedtuple(vi::SimpleOrThreadSafeSimple{<:Dict}) | ||||||
syms_to_result = Dict{Symbol,Tuple{Vector{Real},Vector{String}}}() | ||||||
for vn in keys(vi) | ||||||
# Extract the leaf varnames and values. | ||||||
val = vi[vn] | ||||||
vns = collect(DynamicPPL.TestUtils.varname_leaves(vn, val)) | ||||||
vals = map(Base.Fix1(getindex, vi), vns) | ||||||
|
||||||
# Determine the corresponding symbol. | ||||||
sym = only(unique(map(getsym, vns))) | ||||||
|
||||||
# Initialize entry if not yet initialized. | ||||||
if !haskey(syms_to_result, sym) | ||||||
syms_to_result[sym] = (Real[], String[]) | ||||||
end | ||||||
|
||||||
# Combine with old result. | ||||||
old_vals, old_string_vns = syms_to_result[sym] | ||||||
syms_to_result[sym] = (vcat(old_vals, vals), vcat(old_string_vns, map(string, vns))) | ||||||
end | ||||||
|
||||||
# Construct `NamedTuple`. | ||||||
return NamedTuple(pairs(syms_to_result)) | ||||||
end | ||||||
|
||||||
# NOTE: We don't implement `settrans!!(vi, trans, vn)`. | ||||||
function settrans!!(vi::SimpleVarInfo, trans) | ||||||
|
@@ -525,6 +570,8 @@ istrans(vi::SimpleVarInfo) = !(vi.transformation isa NoTransformation) | |||||
istrans(vi::SimpleVarInfo, vn::VarName) = istrans(vi) | ||||||
istrans(vi::ThreadSafeVarInfo{<:SimpleVarInfo}, vn::VarName) = istrans(vi.varinfo, vn) | ||||||
|
||||||
islinked(vi::SimpleVarInfo, ::Union{Sampler,SampleFromPrior}) = istrans(vi) | ||||||
|
||||||
""" | ||||||
values_as(varinfo[, Type]) | ||||||
|
||||||
|
@@ -536,6 +583,10 @@ values_as(vi::SimpleVarInfo) = vi.values | |||||
values_as(vi::SimpleVarInfo, ::Type{Dict}) = Dict(pairs(vi.values)) | ||||||
values_as(vi::SimpleVarInfo, ::Type{NamedTuple}) = NamedTuple(pairs(vi.values)) | ||||||
values_as(vi::SimpleVarInfo{<:NamedTuple}, ::Type{NamedTuple}) = vi.values | ||||||
function values_as(vi::SimpleVarInfo{<:Any,T}, ::Type{Vector}) where {T} | ||||||
isempty(vi.values) && return T[] | ||||||
return mapreduce(v -> vec([v;]), vcat, values(vi.values)) | ||||||
end | ||||||
|
||||||
""" | ||||||
logjoint(model::Model, θ) | ||||||
|
@@ -632,3 +683,35 @@ julia> # Truth. | |||||
``` | ||||||
""" | ||||||
Distributions.loglikelihood(model::Model, θ) = loglikelihood(model, SimpleVarInfo(θ)) | ||||||
|
||||||
# Allow usage of `NamedBijector` too. | ||||||
function link!!( | ||||||
t::BijectorTransformation{<:Bijectors.NamedBijector}, | ||||||
vi::SimpleVarInfo{<:NamedTuple}, | ||||||
spl::AbstractSampler, | ||||||
model::Model, | ||||||
) | ||||||
# TODO: Make sure that `spl` is respected. | ||||||
b = t.bijector | ||||||
x = vi.values | ||||||
y, logjac = with_logabsdet_jacobian(b, x) | ||||||
lp_new = getlogp(vi) - logjac | ||||||
vi_new = setlogp!!(Setfield.@set(vi.values = y), lp_new) | ||||||
return settrans!!(vi_new, t) | ||||||
end | ||||||
|
||||||
function invlink!!( | ||||||
t::BijectorTransformation{<:Bijectors.NamedBijector}, | ||||||
vi::SimpleVarInfo{<:NamedTuple}, | ||||||
spl::AbstractSampler, | ||||||
model::Model, | ||||||
) | ||||||
# TODO: Make sure that `spl` is respected. | ||||||
b = t.bijector | ||||||
ib = inverse(b) | ||||||
y = vi.values | ||||||
x, logjac = with_logabsdet_jacobian(ib, y) | ||||||
lp_new = getlogp(vi) - logjac | ||||||
vi_new = setlogp!!(Setfield.@set(vi.values = x), lp_new) | ||||||
return settrans!!(vi_new, NoTransformation()) | ||||||
end |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,98 @@ | ||
function Bijectors.Stacked( | ||
model::Model, | ||
::Val{sym2ranges}=Val(false); | ||
varinfo::VarInfo=VarInfo(model), | ||
torfjelde marked this conversation as resolved.
Show resolved
Hide resolved
|
||
) where {sym2ranges} | ||
dists = vcat([varinfo.metadata[sym].dists for sym in keys(varinfo.metadata)]...) | ||
|
||
num_ranges = sum([ | ||
length(varinfo.metadata[sym].ranges) for sym in keys(varinfo.metadata) | ||
]) | ||
ranges = Vector{UnitRange{Int}}(undef, num_ranges) | ||
idx = 0 | ||
range_idx = 1 | ||
|
||
# ranges might be discontinuous => values are vectors of ranges rather than just ranges | ||
sym_lookup = Dict{Symbol,Vector{UnitRange{Int}}}() | ||
for sym in keys(varinfo.metadata) | ||
sym_lookup[sym] = Vector{UnitRange{Int}}() | ||
for r in varinfo.metadata[sym].ranges | ||
ranges[range_idx] = idx .+ r | ||
push!(sym_lookup[sym], ranges[range_idx]) | ||
range_idx += 1 | ||
end | ||
|
||
idx += varinfo.metadata[sym].ranges[end][end] | ||
end | ||
|
||
b = Bijectors.Stacked(map(Bijectors.bijector, dists), ranges) | ||
return sym2ranges ? (b, Dict(zip(keys(sym_lookup), values(sym_lookup)))) : b | ||
end | ||
|
||
link!!(vi::AbstractVarInfo, model::Model) = link!!(vi, SampleFromPrior(), model) | ||
function link!!(t::AbstractTransformation, vi::AbstractVarInfo, model::Model) | ||
return link!!(t, vi, SampleFromPrior(), model) | ||
end | ||
function link!!(vi::AbstractVarInfo, spl::AbstractSampler, model::Model) | ||
# Use `default_transformation` to decide which transformation to use if none is specified. | ||
return link!!(default_transformation(model, vi), vi, spl, model) | ||
end | ||
function link!!( | ||
t::DefaultTransformation, vi::AbstractVarInfo, spl::AbstractSampler, model::Model | ||
) | ||
# TODO: Implement this properly, e.g. using a context or something. | ||
# Fall back to `Bijectors.Stacked` but then we act like we're using | ||
# the `DefaultTransformation` by setting the transformation accordingly. | ||
return settrans!!( | ||
link!!(BijectorTransformation(Bijectors.Stacked(model)), vi, spl, model), t | ||
) | ||
end | ||
function link!!(t::DefaultTransformation, vi::VarInfo, spl::AbstractSampler, model::Model) | ||
# TODO: Implement this properly, e.g. using a context or something. | ||
DynamicPPL.link!(vi, spl) | ||
# TODO: Add `logabsdet_jacobian` correction to `logp`! | ||
return vi | ||
end | ||
function link!!( | ||
t::BijectorTransformation, vi::AbstractVarInfo, spl::AbstractSampler, model::Model | ||
) | ||
b = t.bijector | ||
x = vi[spl] | ||
y, logjac = with_logabsdet_jacobian(b, x) | ||
|
||
lp_new = getlogp(vi) - logjac | ||
vi_new = setlogp!!(unflatten(vi, spl, y), lp_new) | ||
return settrans!!(vi_new, t) | ||
end | ||
|
||
invlink!!(vi::AbstractVarInfo, model::Model) = invlink!!(vi, SampleFromPrior(), model) | ||
function invlink!!(t::AbstractTransformation, vi::AbstractVarInfo, model::Model) | ||
return invlink!!(t, vi, SampleFromPrior(), model) | ||
end | ||
function invlink!!(vi::AbstractVarInfo, spl::AbstractSampler, model::Model) | ||
# Here we extract the `transformation` from `vi` rather than using the default one. | ||
return invlink!!(transformation(vi), vi, spl, model) | ||
end | ||
function invlink!!( | ||
::DefaultTransformation, vi::AbstractVarInfo, spl::AbstractSampler, model::Model | ||
) | ||
# TODO: Implement this properly, e.g. using a context or something. | ||
return invlink!!(BijectorTransformation(Bijectors.Stacked(model)), vi, spl, model) | ||
end | ||
function invlink!!(::DefaultTransformation, vi::VarInfo, spl::AbstractSampler, model::Model) | ||
# TODO: Implement this properly, e.g. using a context or something. | ||
DynamicPPL.invlink!(vi, spl) | ||
return vi | ||
end | ||
function invlink!!( | ||
t::BijectorTransformation, vi::AbstractVarInfo, spl::AbstractSampler, model::Model | ||
) | ||
b = t.bijector | ||
ib = inverse(b) | ||
y = vi[spl] | ||
x, logjac = with_logabsdet_jacobian(ib, y) | ||
# TODO: Do we need this? | ||
lp_new = getlogp(vi) - logjac | ||
vi_new = setlogp!!(unflatten(vi, spl, x), lp_new) | ||
return settrans!!(vi_new, NoTransformation()) | ||
end |
Uh oh!
There was an error while loading. Please reload this page.