Skip to content

Commit e31a790

Browse files
committed
Use OrderedDict in SimpleVarInfo + improvements and fixes for values_as (#420)
We are currently using `Dict` together with `SimpleVarInfo` which leads to inconsistent ordering of the variables vs. `OrderdDict` which, if generated from a `Model`, will preserve the execution order of the model. In addition, I've fixed some impls for `values_as` + added more better support, in addition to proper testing. Given how it's now better tested + is a nice-to-have feature + will likely see extensive use after #417, it also seems reasonable to export `values_as` from DPPL. EDIT: This should be merged before #417
1 parent 08ef935 commit e31a790

File tree

10 files changed

+223
-31
lines changed

10 files changed

+223
-31
lines changed

Project.toml

+2
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
1313
DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"
1414
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
1515
MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
16+
OrderedCollections = "bac558e1-5e72-5ebc-8fee-abe8a469f55d"
1617
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
1718
Setfield = "efcf1570-3423-57d1-acb7-fd33fddbac46"
1819
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
@@ -28,6 +29,7 @@ ConstructionBase = "1"
2829
Distributions = "0.23.8, 0.24, 0.25"
2930
DocStringExtensions = "0.8, 0.9"
3031
MacroTools = "0.5.6"
32+
OrderedCollections = "1"
3133
Setfield = "0.7.1, 0.8"
3234
ZygoteRules = "0.2"
3335
julia = "1.6"

docs/src/api.md

+4
Original file line numberDiff line numberDiff line change
@@ -167,6 +167,10 @@ push!!
167167
empty!!
168168
```
169169

170+
```@docs
171+
values_as
172+
```
173+
170174
#### `SimpleVarInfo`
171175

172176
```@docs

src/DynamicPPL.jl

+4
Original file line numberDiff line numberDiff line change
@@ -4,11 +4,13 @@ using AbstractMCMC: AbstractSampler, AbstractChains
44
using AbstractPPL
55
using Bijectors
66
using Distributions
7+
using OrderedCollections: OrderedDict
78

89
using AbstractMCMC: AbstractMCMC
910
using BangBang: BangBang, push!!, empty!!, setindex!!
1011
using ChainRulesCore: ChainRulesCore
1112
using MacroTools: MacroTools
13+
using ConstructionBase: ConstructionBase
1214
using Setfield: Setfield
1315
using ZygoteRules: ZygoteRules
1416

@@ -59,6 +61,7 @@ export AbstractVarInfo,
5961
link!,
6062
invlink!,
6163
tonamedtuple,
64+
values_as,
6265
# VarName (reexport from AbstractPPL)
6366
VarName,
6467
inspace,
@@ -73,6 +76,7 @@ export AbstractVarInfo,
7376
Sample,
7477
init,
7578
vectorize,
79+
OrderedDict,
7680
# Model
7781
Model,
7882
getmissings,

src/model.jl

+1-1
Original file line numberDiff line numberDiff line change
@@ -625,7 +625,7 @@ function Base.rand(rng::Random.AbstractRNG, ::Type{T}, model::Model) where {T}
625625
x = last(
626626
evaluate!!(
627627
model,
628-
SimpleVarInfo{Float64}(),
628+
SimpleVarInfo{Float64}(OrderedDict()),
629629
SamplingContext(rng, SampleFromPrior(), DefaultContext()),
630630
),
631631
)

src/simple_varinfo.jl

+21-19
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ struct DefaultTransformation <: AbstractTransformation end
99
A simple wrapper of the parameters with a `logp` field for
1010
accumulation of the logdensity.
1111
12-
Currently only implemented for `NT<:NamedTuple` and `NT<:Dict`.
12+
Currently only implemented for `NT<:NamedTuple` and `NT<:AbstractDict`.
1313
1414
# Fields
1515
$(FIELDS)
@@ -69,8 +69,8 @@ julia> # (×) If we don't provide the container...
6969
ERROR: type NamedTuple has no field x
7070
[...]
7171
72-
julia> # If one does not know the varnames, we can use a `Dict` instead.
73-
_, vi = DynamicPPL.evaluate!!(m, SimpleVarInfo{Float64}(Dict()), ctx);
72+
julia> # If one does not know the varnames, we can use a `OrderedDict` instead.
73+
_, vi = DynamicPPL.evaluate!!(m, SimpleVarInfo{Float64}(OrderedDict()), ctx);
7474
7575
julia> # (✓) Sort of fast, but only possible at runtime.
7676
vi[@varname(x[1])]
@@ -86,6 +86,11 @@ ERROR: KeyError: key x[1:2] not found
8686
[...]
8787
```
8888
89+
_Technically_, it's possible to use any implementation of `AbstractDict` in place of
90+
`OrderedDict`, but `OrderedDict` ensures that certain operations, e.g. linearization/flattening
91+
of the values in the varinfo, are consistent between evaluations. Hence `OrderedDict` is
92+
the preferred implementation of `AbstractDict` to use here.
93+
8994
You can also sample in _transformed_ space:
9095
9196
```jldoctest simplevarinfo-general
@@ -109,8 +114,8 @@ julia> xs = [last(DynamicPPL.evaluate!!(m, DynamicPPL.settrans!!(SimpleVarInfo()
109114
julia> any(xs .< 0) # (✓) Positive probability mass on negative numbers!
110115
true
111116
112-
julia> # And with `Dict` of course!
113-
_, vi = DynamicPPL.evaluate!!(m, DynamicPPL.settrans!!(SimpleVarInfo(Dict()), true), ctx);
117+
julia> # And with `OrderedDict` of course!
118+
_, vi = DynamicPPL.evaluate!!(m, DynamicPPL.settrans!!(SimpleVarInfo(OrderedDict()), true), ctx);
114119
115120
julia> vi[@varname(x)] # (✓) -∞ < x < ∞
116121
0.6225185067787314
@@ -165,9 +170,9 @@ ERROR: type NamedTuple has no field b
165170
[...]
166171
```
167172
168-
Using `Dict` as underlying storage.
173+
Using `OrderedDict` as underlying storage.
169174
```jldoctest
170-
julia> svi_dict = SimpleVarInfo(Dict(@varname(m) => (a = [1.0], )));
175+
julia> svi_dict = SimpleVarInfo(OrderedDict(@varname(m) => (a = [1.0], )));
171176
172177
julia> svi_dict[@varname(m)]
173178
(a = [1.0],)
@@ -274,7 +279,7 @@ end
274279

275280
Base.getindex(vi::SimpleVarInfo, vn::VarName) = get(vi.values, vn)
276281

277-
# `Dict`
282+
# `AbstractDict`
278283
function Base.getindex(vi::SimpleVarInfo{<:AbstractDict}, vn::VarName)
279284
return nested_getindex(vi.values, vn)
280285
end
@@ -364,7 +369,7 @@ function BangBang.push!!(
364369
return Setfield.@set vi.values = set!!(vi.values, vn, value)
365370
end
366371

367-
# `Dict`
372+
# `AbstractDict`
368373
function BangBang.push!!(
369374
vi::SimpleVarInfo{<:AbstractDict},
370375
vn::VarName,
@@ -473,17 +478,14 @@ istrans(vi::SimpleVarInfo) = !(vi.transformation isa NoTransformation)
473478
istrans(vi::SimpleVarInfo, vn::VarName) = istrans(vi)
474479
istrans(vi::ThreadSafeVarInfo{<:SimpleVarInfo}, vn::VarName) = istrans(vi.varinfo, vn)
475480

476-
"""
477-
values_as(varinfo[, Type])
478-
479-
Return the values/realizations in `varinfo` as `Type`, if implemented.
480-
481-
If no `Type` is provided, return values as stored in `varinfo`.
482-
"""
483481
values_as(vi::SimpleVarInfo) = vi.values
484-
values_as(vi::SimpleVarInfo, ::Type{Dict}) = Dict(pairs(vi.values))
485-
values_as(vi::SimpleVarInfo, ::Type{NamedTuple}) = NamedTuple(pairs(vi.values))
486-
values_as(vi::SimpleVarInfo{<:NamedTuple}, ::Type{NamedTuple}) = vi.values
482+
values_as(vi::SimpleVarInfo{<:T}, ::Type{T}) where {T} = vi.values
483+
function values_as(vi::SimpleVarInfo, ::Type{D}) where {D<:AbstractDict}
484+
return ConstructionBase.constructorof(D)(zip(keys(vi), values(vi.values)))
485+
end
486+
function values_as(vi::SimpleVarInfo{<:AbstractDict}, ::Type{NamedTuple})
487+
return NamedTuple((Symbol(k), v) for (k, v) in vi.values)
488+
end
487489

488490
"""
489491
logjoint(model::Model, θ)

src/varinfo.jl

+91-10
Original file line numberDiff line numberDiff line change
@@ -1550,30 +1550,111 @@ function _setval_and_resample_kernel!(vi::VarInfo, vn::VarName, values, keys)
15501550
end
15511551

15521552
"""
1553-
values_as(vi::AbstractVarInfo)
1554-
"""
1555-
values_as(vi::VarInfo) = vi.metadata
1553+
values_as(varinfo[, Type])
15561554
1557-
"""
1558-
values_as(vi::AbstractVarInfo, ::Type{NamedTuple})
1559-
values_as(vi::AbstractVarInfo, ::Type{Dict})
1555+
Return the values/realizations in `varinfo` as `Type`, if implemented.
1556+
1557+
If no `Type` is provided, return values as stored in `varinfo`.
1558+
1559+
# Examples
1560+
1561+
`SimpleVarInfo` with `NamedTuple`:
1562+
1563+
```jldoctest
1564+
julia> data = (x = 1.0, m = [2.0]);
1565+
1566+
julia> values_as(SimpleVarInfo(data))
1567+
(x = 1.0, m = [2.0])
1568+
1569+
julia> values_as(SimpleVarInfo(data), NamedTuple)
1570+
(x = 1.0, m = [2.0])
1571+
1572+
julia> values_as(SimpleVarInfo(data), OrderedDict)
1573+
OrderedDict{VarName{sym, Setfield.IdentityLens} where sym, Any} with 2 entries:
1574+
x => 1.0
1575+
m => [2.0]
1576+
```
1577+
1578+
`SimpleVarInfo` with `OrderedDict`:
1579+
1580+
```jldoctest
1581+
julia> data = OrderedDict{Any,Any}(@varname(x) => 1.0, @varname(m) => [2.0]);
1582+
1583+
julia> values_as(SimpleVarInfo(data))
1584+
OrderedDict{Any, Any} with 2 entries:
1585+
x => 1.0
1586+
m => [2.0]
1587+
1588+
julia> values_as(SimpleVarInfo(data), NamedTuple)
1589+
(x = 1.0, m = [2.0])
1590+
1591+
julia> values_as(SimpleVarInfo(data), OrderedDict)
1592+
OrderedDict{Any, Any} with 2 entries:
1593+
x => 1.0
1594+
m => [2.0]
1595+
```
1596+
1597+
`TypedVarInfo`:
1598+
1599+
```jldoctest
1600+
julia> # Just use an example model to construct the `VarInfo` because we're lazy.
1601+
vi = VarInfo(DynamicPPL.TestUtils.demo_assume_dot_observe());
1602+
1603+
julia> vi[@varname(s)] = 1.0; vi[@varname(m)] = 2.0;
1604+
1605+
julia> # For the sake of brevity, let's just check the type.
1606+
md = values_as(vi); md.s isa DynamicPPL.Metadata
1607+
true
1608+
1609+
julia> values_as(vi, NamedTuple)
1610+
(s = 1.0, m = 2.0)
1611+
1612+
julia> values_as(vi, OrderedDict)
1613+
OrderedDict{VarName{sym, Setfield.IdentityLens} where sym, Float64} with 2 entries:
1614+
s => 1.0
1615+
m => 2.0
1616+
```
15601617
1561-
Return values in `vi` as the specified type.
1618+
`UntypedVarInfo`:
1619+
1620+
```jldoctest
1621+
julia> # Just use an example model to construct the `VarInfo` because we're lazy.
1622+
vi = VarInfo(); DynamicPPL.TestUtils.demo_assume_dot_observe()(vi);
1623+
1624+
julia> vi[@varname(s)] = 1.0; vi[@varname(m)] = 2.0;
1625+
1626+
julia> # For the sake of brevity, let's just check the type.
1627+
values_as(vi) isa DynamicPPL.Metadata
1628+
true
1629+
1630+
julia> values_as(vi, NamedTuple)
1631+
(s = 1.0, m = 2.0)
1632+
1633+
julia> values_as(vi, OrderedDict)
1634+
OrderedDict{VarName{sym, Setfield.IdentityLens} where sym, Float64} with 2 entries:
1635+
s => 1.0
1636+
m => 2.0
1637+
```
15621638
"""
1639+
values_as(vi::VarInfo) = vi.metadata
15631640
function values_as(vi::UntypedVarInfo, ::Type{NamedTuple})
15641641
iter = values_from_metadata(vi.metadata)
15651642
return NamedTuple(map(p -> Symbol(p.first) => p.second, iter))
15661643
end
1567-
values_as(vi::UntypedVarInfo, ::Type{Dict}) = Dict(values_from_metadata(vi.metadata))
1644+
function values_as(vi::UntypedVarInfo, ::Type{D}) where {D<:AbstractDict}
1645+
return ConstructionBase.constructorof(D)(values_from_metadata(vi.metadata))
1646+
end
15681647

15691648
function values_as(vi::VarInfo{<:NamedTuple{names}}, ::Type{NamedTuple}) where {names}
15701649
iter = Iterators.flatten(values_from_metadata(getfield(vi.metadata, n)) for n in names)
15711650
return NamedTuple(map(p -> Symbol(p.first) => p.second, iter))
15721651
end
15731652

1574-
function values_as(vi::VarInfo{<:NamedTuple{names}}, ::Type{Dict}) where {names}
1653+
function values_as(
1654+
vi::VarInfo{<:NamedTuple{names}}, ::Type{D}
1655+
) where {names,D<:AbstractDict}
15751656
iter = Iterators.flatten(values_from_metadata(getfield(vi.metadata, n)) for n in names)
1576-
return Dict(iter)
1657+
return ConstructionBase.constructorof(D)(iter)
15771658
end
15781659

15791660
function values_from_metadata(md::Metadata)

src/varname.jl

+3
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,6 @@
1+
# FIXME: This fix should be in `AbstractPPL`.
2+
AbstractPPL.subsumes(::Setfield.IdentityLens, ::Setfield.IdentityLens) = true
3+
14
"""
25
subsumes_string(u::String, v::String[, u_indexing])
36

test/model.jl

+1-1
Original file line numberDiff line numberDiff line change
@@ -134,7 +134,7 @@ end
134134
Random.seed!(1776)
135135
s, m = model()
136136
sample_namedtuple = (; s=s, m=m)
137-
sample_dict = Dict(:s => s, :m => m)
137+
sample_dict = Dict(@varname(s) => s, @varname(m) => m)
138138

139139
# With explicit RNG
140140
@test rand(Random.seed!(1776), model) == sample_namedtuple

test/test_util.jl

+55
Original file line numberDiff line numberDiff line change
@@ -74,3 +74,58 @@ function test_setval!(model, chain; sample_idx=1, chain_idx=1)
7474
end
7575
end
7676
end
77+
78+
"""
79+
short_varinfo_name(vi::AbstractVarInfo)
80+
81+
Return string representing a short description of `vi`.
82+
"""
83+
short_varinfo_name(vi::DynamicPPL.ThreadSafeVarInfo) = short_varinfo_name(vi.varinfo)
84+
short_varinfo_name(::TypedVarInfo) = "TypedVarInfo"
85+
short_varinfo_name(::UntypedVarInfo) = "UntypedVarInfo"
86+
short_varinfo_name(::SimpleVarInfo{<:NamedTuple}) = "SimpleVarInfo{<:NamedTuple}"
87+
short_varinfo_name(::SimpleVarInfo{<:OrderedDict}) = "SimpleVarInfo{<:OrderedDict}"
88+
89+
"""
90+
update_values!!(vi::AbstractVarInfo, vals::NamedTuple, vns)
91+
92+
Return instance similar to `vi` but with `vns` set to values from `vals`.
93+
"""
94+
function update_values!!(vi::AbstractVarInfo, vals::NamedTuple, vns)
95+
for vn in vns
96+
vi = DynamicPPL.setindex!!(vi, get(vals, vn), vn)
97+
end
98+
return vi
99+
end
100+
101+
"""
102+
test_values(vi::AbstractVarInfo, vals::NamedTuple, vns)
103+
104+
Test that `vi[vn]` corresponds to the correct value in `vals` for every `vn` in `vns`.
105+
"""
106+
function test_values(vi::AbstractVarInfo, vals::NamedTuple, vns)
107+
for vn in vns
108+
@test vi[vn] == get(vals, vn)
109+
end
110+
end
111+
112+
"""
113+
setup_varinfos(model::Model, example_values::NamedTuple, varnames)
114+
115+
Return a tuple of instances for different implementations of `AbstractVarInfo` with
116+
each `vi`, supposedly, satisfying `vi[vn] == get(example_values, vn)` for `vn` in `varnames`.
117+
"""
118+
function setup_varinfos(model::Model, example_values::NamedTuple, varnames)
119+
# <:VarInfo
120+
vi_untyped = VarInfo()
121+
model(vi_untyped)
122+
vi_typed = TypedVarInfo(vi_untyped)
123+
# <:SimpleVarInfo
124+
svi_typed = SimpleVarInfo(example_values)
125+
svi_untyped = SimpleVarInfo(OrderedDict())
126+
127+
return map((vi_untyped, vi_typed, svi_typed, svi_untyped)) do vi
128+
# Set them all to the same values.
129+
update_values!!(vi, example_values, varnames)
130+
end
131+
end

0 commit comments

Comments
 (0)