@@ -9,7 +9,7 @@ struct DefaultTransformation <: AbstractTransformation end
9
9
A simple wrapper of the parameters with a `logp` field for
10
10
accumulation of the logdensity.
11
11
12
- Currently only implemented for `NT<:NamedTuple` and `NT<:Dict `.
12
+ Currently only implemented for `NT<:NamedTuple` and `NT<:AbstractDict `.
13
13
14
14
# Fields
15
15
$(FIELDS)
@@ -69,8 +69,8 @@ julia> # (×) If we don't provide the container...
69
69
ERROR: type NamedTuple has no field x
70
70
[...]
71
71
72
- julia> # If one does not know the varnames, we can use a `Dict ` instead.
73
- _, vi = DynamicPPL.evaluate!!(m, SimpleVarInfo{Float64}(Dict ()), ctx);
72
+ julia> # If one does not know the varnames, we can use a `OrderedDict ` instead.
73
+ _, vi = DynamicPPL.evaluate!!(m, SimpleVarInfo{Float64}(OrderedDict ()), ctx);
74
74
75
75
julia> # (✓) Sort of fast, but only possible at runtime.
76
76
vi[@varname(x[1])]
@@ -86,6 +86,11 @@ ERROR: KeyError: key x[1:2] not found
86
86
[...]
87
87
```
88
88
89
+ _Technically_, it's possible to use any implementation of `AbstractDict` in place of
90
+ `OrderedDict`, but `OrderedDict` ensures that certain operations, e.g. linearization/flattening
91
+ of the values in the varinfo, are consistent between evaluations. Hence `OrderedDict` is
92
+ the preferred implementation of `AbstractDict` to use here.
93
+
89
94
You can also sample in _transformed_ space:
90
95
91
96
```jldoctest simplevarinfo-general
@@ -109,8 +114,8 @@ julia> xs = [last(DynamicPPL.evaluate!!(m, DynamicPPL.settrans!!(SimpleVarInfo()
109
114
julia> any(xs .< 0) # (✓) Positive probability mass on negative numbers!
110
115
true
111
116
112
- julia> # And with `Dict ` of course!
113
- _, vi = DynamicPPL.evaluate!!(m, DynamicPPL.settrans!!(SimpleVarInfo(Dict ()), true), ctx);
117
+ julia> # And with `OrderedDict ` of course!
118
+ _, vi = DynamicPPL.evaluate!!(m, DynamicPPL.settrans!!(SimpleVarInfo(OrderedDict ()), true), ctx);
114
119
115
120
julia> vi[@varname(x)] # (✓) -∞ < x < ∞
116
121
0.6225185067787314
@@ -165,9 +170,9 @@ ERROR: type NamedTuple has no field b
165
170
[...]
166
171
```
167
172
168
- Using `Dict ` as underlying storage.
173
+ Using `OrderedDict ` as underlying storage.
169
174
```jldoctest
170
- julia> svi_dict = SimpleVarInfo(Dict (@varname(m) => (a = [1.0], )));
175
+ julia> svi_dict = SimpleVarInfo(OrderedDict (@varname(m) => (a = [1.0], )));
171
176
172
177
julia> svi_dict[@varname(m)]
173
178
(a = [1.0],)
274
279
275
280
Base. getindex (vi:: SimpleVarInfo , vn:: VarName ) = get (vi. values, vn)
276
281
277
- # `Dict `
282
+ # `AbstractDict `
278
283
function Base. getindex (vi:: SimpleVarInfo{<:AbstractDict} , vn:: VarName )
279
284
return nested_getindex (vi. values, vn)
280
285
end
@@ -364,7 +369,7 @@ function BangBang.push!!(
364
369
return Setfield. @set vi. values = set!! (vi. values, vn, value)
365
370
end
366
371
367
- # `Dict `
372
+ # `AbstractDict `
368
373
function BangBang. push!! (
369
374
vi:: SimpleVarInfo{<:AbstractDict} ,
370
375
vn:: VarName ,
@@ -473,17 +478,14 @@ istrans(vi::SimpleVarInfo) = !(vi.transformation isa NoTransformation)
473
478
istrans (vi:: SimpleVarInfo , vn:: VarName ) = istrans (vi)
474
479
istrans (vi:: ThreadSafeVarInfo{<:SimpleVarInfo} , vn:: VarName ) = istrans (vi. varinfo, vn)
475
480
476
- """
477
- values_as(varinfo[, Type])
478
-
479
- Return the values/realizations in `varinfo` as `Type`, if implemented.
480
-
481
- If no `Type` is provided, return values as stored in `varinfo`.
482
- """
483
481
values_as (vi:: SimpleVarInfo ) = vi. values
484
- values_as (vi:: SimpleVarInfo , :: Type{Dict} ) = Dict (pairs (vi. values))
485
- values_as (vi:: SimpleVarInfo , :: Type{NamedTuple} ) = NamedTuple (pairs (vi. values))
486
- values_as (vi:: SimpleVarInfo{<:NamedTuple} , :: Type{NamedTuple} ) = vi. values
482
+ values_as (vi:: SimpleVarInfo{<:T} , :: Type{T} ) where {T} = vi. values
483
+ function values_as (vi:: SimpleVarInfo , :: Type{D} ) where {D<: AbstractDict }
484
+ return ConstructionBase. constructorof (D)(zip (keys (vi), values (vi. values)))
485
+ end
486
+ function values_as (vi:: SimpleVarInfo{<:AbstractDict} , :: Type{NamedTuple} )
487
+ return NamedTuple ((Symbol (k), v) for (k, v) in vi. values)
488
+ end
487
489
488
490
"""
489
491
logjoint(model::Model, θ)
0 commit comments