Skip to content

Unify {untyped,typed}_{vector_,}varinfo constructor functions #879

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 12 commits into from
Apr 9, 2025
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 26 additions & 1 deletion HISTORY.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,35 @@

**Breaking changes**

### VarInfo constructor
### VarInfo constructors

`VarInfo(vi::VarInfo, values)` has been removed. You can replace this directly with `unflatten(vi, values)` instead.

**The `VarInfo([rng, ]model[, sampler, context, metadata])` constructor has been replaced with the following methods:**

1. `UntypedVarInfo([rng, ]model[, sampler, context])`
2. `TypedVarInfo([rng, ]model[, sampler, context])`
3. `DynamicPPL.UntypedVectorVarInfo([rng, ]model[, sampler, context])`
4. `DynamicPPL.TypedVectorVarInfo([rng, ]model[, sampler, context])`

**If you were not using the `metadata` argument (most likely), then you can directly replace calls to this constructor with `TypedVarInfo` instead.**
That is to say, if you were using `VarInfo(model)`, you can replace this with `TypedVarInfo(model)`.

If you were using the `metadata` argument to specify a blank `VarNamedVector`, then you should replace calls to `VarInfo` with `TypedVectorVarInfo` instead.
Note that the `VectorVarInfo` constructors (both `Untyped` and `Typed`) are not exported by default.

If you were passing a non-empty metadata argument, you should use a different constructor of `VarInfo` instead.

The reason for this change is that there were several flavours of VarInfo.
Some, like TypedVarInfo, were easy to construct because we had convenience methods for them; however, the others were more difficult.
This change makes it easier to access different VarInfo types, and also makes it more explicit which one you are constructing.

The `untyped_varinfo` and `typed_varinfo` functions have also been removed; you can use `UntypedVarInfo` and `TypedVarInfo` as direct replacements.

Finally, `TypedVarInfo` is no longer a type.
It has been replaced with `NTVarInfo`.
If you were dispatching on this, you should replace it with `NTVarInfo` instead.

### VarName prefixing behaviour

The way in which VarNames in submodels are prefixed has been changed.
Expand Down
12 changes: 5 additions & 7 deletions benchmarks/src/DynamicPPLBenchmarks.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
module DynamicPPLBenchmarks

using DynamicPPL: VarInfo, SimpleVarInfo, VarName
using DynamicPPL: VarInfo, UntypedVarInfo, TypedVarInfo, SimpleVarInfo, VarName
using BenchmarkTools: BenchmarkGroup, @benchmarkable
using DynamicPPL: DynamicPPL
using ADTypes: ADTypes
Expand Down Expand Up @@ -52,8 +52,8 @@ end

Create a benchmark suite for `model` using the selected varinfo type and AD backend.
Available varinfo choices:
• `:untyped` → uses `VarInfo()`
• `:typed` → uses `VarInfo(model)`
• `:untyped` → uses `UntypedVarInfo(model)`
• `:typed` → uses `TypedVarInfo(model)`
• `:simple_namedtuple` → uses `SimpleVarInfo{Float64}(model())`
• `:simple_dict` → builds a `SimpleVarInfo{Float64}` from a Dict (pre-populated with the model’s outputs)

Expand All @@ -67,11 +67,9 @@ function make_suite(model, varinfo_choice::Symbol, adbackend::Symbol, islinked::
suite = BenchmarkGroup()

vi = if varinfo_choice == :untyped
vi = VarInfo()
model(rng, vi)
vi
UntypedVarInfo(rng, model)
elseif varinfo_choice == :typed
VarInfo(rng, model)
TypedVarInfo(rng, model)
elseif varinfo_choice == :simple_namedtuple
SimpleVarInfo{Float64}(model(rng))
elseif varinfo_choice == :simple_dict
Expand Down
10 changes: 2 additions & 8 deletions docs/src/api.md
Original file line number Diff line number Diff line change
Expand Up @@ -291,17 +291,11 @@ AbstractVarInfo

But exactly how a [`AbstractVarInfo`](@ref) stores this information can vary.

For constructing the "default" typed and untyped varinfo types used in DynamicPPL (see [the section on varinfo design](@ref "Design of `VarInfo`") for more on this), we have the following two methods:

```@docs
DynamicPPL.untyped_varinfo
DynamicPPL.typed_varinfo
```

#### `VarInfo`

```@docs
VarInfo
UntypedVarInfo
TypedVarInfo
```

Expand Down Expand Up @@ -455,7 +449,7 @@ Finally, to specify which varinfo type a [`Sampler`](@ref) should use for a give
DynamicPPL.default_varinfo
```

There is also the _experimental_ [`DynamicPPL.Experimental.determine_suitable_varinfo`](@ref), which uses static checking via [JET.jl](https://github.com/aviatesk/JET.jl) to determine whether one should use [`DynamicPPL.typed_varinfo`](@ref) or [`DynamicPPL.untyped_varinfo`](@ref), depending on which supports the model:
There is also the _experimental_ [`DynamicPPL.Experimental.determine_suitable_varinfo`](@ref), which uses static checking via [JET.jl](https://github.com/aviatesk/JET.jl) to determine whether one should use [`DynamicPPL.TypedVarInfo`](@ref) or [`DynamicPPL.UntypedVarInfo`](@ref), depending on which supports the model:

```@docs
DynamicPPL.Experimental.determine_suitable_varinfo
Expand Down
10 changes: 5 additions & 5 deletions docs/src/internals/varinfo.md
Original file line number Diff line number Diff line change
Expand Up @@ -79,13 +79,13 @@ For example, with the model above we have

```@example varinfo-design
# Type-unstable `VarInfo`
varinfo_untyped = DynamicPPL.untyped_varinfo(demo())
varinfo_untyped = DynamicPPL.UntypedVarInfo(demo())
typeof(varinfo_untyped.metadata)
```

```@example varinfo-design
# Type-stable `VarInfo`
varinfo_typed = DynamicPPL.typed_varinfo(demo())
varinfo_typed = DynamicPPL.TypedVarInfo(demo())
typeof(varinfo_typed.metadata)
```

Expand Down Expand Up @@ -154,7 +154,7 @@ For example, we want to optimize code-paths which effectively boil down to inner

```julia
# Construct a `VarInfo` with types inferred from `model`.
varinfo = VarInfo(model)
varinfo = TypedVarInfo(model)

# Repeatedly sample from `model`.
for _ in 1:num_samples
Expand Down Expand Up @@ -227,13 +227,13 @@ Continuing from the example from the previous section, we can use a `VarInfo` wi

```@example varinfo-design
# Type-unstable
varinfo_untyped_vnv = DynamicPPL.VectorVarInfo(varinfo_untyped)
varinfo_untyped_vnv = DynamicPPL.UntypedVectorVarInfo(varinfo_untyped)
varinfo_untyped_vnv[@varname(x)], varinfo_untyped_vnv[@varname(y)]
```

```@example varinfo-design
# Type-stable
varinfo_typed_vnv = DynamicPPL.VectorVarInfo(varinfo_typed)
varinfo_typed_vnv = DynamicPPL.TypedVectorVarInfo(varinfo_typed)
varinfo_typed_vnv[@varname(x)], varinfo_typed_vnv[@varname(y)]
```

Expand Down
4 changes: 2 additions & 2 deletions ext/DynamicPPLJETExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
model::DynamicPPL.Model, context::DynamicPPL.AbstractContext; only_ddpl::Bool=true
)
# First we try with the typed varinfo.
varinfo = DynamicPPL.typed_varinfo(model, context)
varinfo = DynamicPPL.TypedVarInfo(model, context)

Check warning on line 30 in ext/DynamicPPLJETExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/DynamicPPLJETExt.jl#L30

Added line #L30 was not covered by tests

# Let's make sure that both evaluation and sampling doesn't result in type errors.
issuccess, result = DynamicPPL.Experimental.is_suitable_varinfo(
Expand All @@ -46,7 +46,7 @@
else
# Warn the user that we can't use the type stable one.
@warn "Model seems incompatible with typed varinfo. Falling back to untyped varinfo."
DynamicPPL.untyped_varinfo(model, context)
DynamicPPL.UntypedVarInfo(model, context)

Check warning on line 49 in ext/DynamicPPLJETExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/DynamicPPLJETExt.jl#L49

Added line #L49 was not covered by tests
end
end

Expand Down
1 change: 1 addition & 0 deletions src/DynamicPPL.jl
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ export AbstractVarInfo,
UntypedVarInfo,
TypedVarInfo,
SimpleVarInfo,
NTVarInfo,
push!!,
empty!!,
subset,
Expand Down
6 changes: 3 additions & 3 deletions src/abstract_varinfo.jl
Original file line number Diff line number Diff line change
Expand Up @@ -251,7 +251,7 @@ julia> values_as(SimpleVarInfo(data), Vector)

```jldoctest
julia> # Just use an example model to construct the `VarInfo` because we're lazy.
vi = VarInfo(DynamicPPL.TestUtils.demo_assume_dot_observe());
vi = TypedVarInfo(DynamicPPL.TestUtils.demo_assume_dot_observe());

julia> vi[@varname(s)] = 1.0; vi[@varname(m)] = 2.0;

Expand All @@ -277,7 +277,7 @@ julia> values_as(vi, Vector)

```jldoctest
julia> # Just use an example model to construct the `VarInfo` because we're lazy.
vi = VarInfo(); DynamicPPL.TestUtils.demo_assume_dot_observe()(vi);
vi = UntypedVarInfo(DynamicPPL.TestUtils.demo_assume_dot_observe());

julia> vi[@varname(s)] = 1.0; vi[@varname(m)] = 2.0;

Expand Down Expand Up @@ -354,7 +354,7 @@ demo (generic function with 2 methods)

julia> model = demo();

julia> varinfo = VarInfo(model);
julia> varinfo = TypedVarInfo(model);

julia> keys(varinfo)
4-element Vector{VarName}:
Expand Down
13 changes: 13 additions & 0 deletions src/deprecated.jl
Original file line number Diff line number Diff line change
@@ -1 +1,14 @@
@deprecate generated_quantities(model, params) returned(model, params)

Base.@deprecate VarInfo(
rng::Random.AbstractRNG,
model::Model,
sampler::AbstractSampler=SampleFromPrior(),
context::AbstractContext=DefaultContext(),
) TypedVarInfo(rng, model, sampler, context)
Base.@deprecate VarInfo(
model::Model,
sampler::AbstractSampler=SampleFromPrior(),
context::AbstractContext=DefaultContext(),
) TypedVarInfo(model, sampler, context)
Base.@deprecate VarInfo(model::Model, context::AbstractContext) TypedVarInfo(model, context)
6 changes: 3 additions & 3 deletions src/experimental.jl
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@
┌ Warning: Model seems incompatible with typed varinfo. Falling back to untyped varinfo.
└ @ DynamicPPLJETExt ~/.julia/dev/DynamicPPL.jl/ext/DynamicPPLJETExt.jl:48

julia> vi isa typeof(DynamicPPL.untyped_varinfo(model))
julia> vi isa typeof(DynamicPPL.UntypedVarInfo(model))
true

julia> # In contrast, a simple model with no random support can be handled by typed varinfo.
Expand All @@ -81,7 +81,7 @@

julia> vi = determine_suitable_varinfo(model_with_static_support());

julia> vi isa typeof(DynamicPPL.typed_varinfo(model_with_static_support()))
julia> vi isa typeof(DynamicPPL.TypedVarInfo(model_with_static_support()))
true
```
"""
Expand All @@ -97,7 +97,7 @@
# Warn the user.
@warn "JET.jl is not loaded. Assumes the model is compatible with typed varinfo."
# Otherwise, we use the, possibly incorrect, default typed varinfo (to stay backwards compat).
DynamicPPL.typed_varinfo(model, context)
DynamicPPL.TypedVarInfo(model, context)

Check warning on line 100 in src/experimental.jl

View check run for this annotation

Codecov / codecov/patch

src/experimental.jl#L100

Added line #L100 was not covered by tests
end
end

Expand Down
4 changes: 2 additions & 2 deletions src/logdensityfunction.jl
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ julia> LogDensityProblems.logdensity(f, [0.0])
-2.3378770664093453

julia> # This also respects the context in `model`.
f_prior = LogDensityFunction(contextualize(model, DynamicPPL.PriorContext()), VarInfo(model));
f_prior = LogDensityFunction(contextualize(model, DynamicPPL.PriorContext()), TypedVarInfo(model));

julia> LogDensityProblems.logdensity(f_prior, [0.0]) == logpdf(Normal(), 0.0)
true
Expand Down Expand Up @@ -109,7 +109,7 @@ struct LogDensityFunction{

function LogDensityFunction(
model::Model,
varinfo::AbstractVarInfo=VarInfo(model),
varinfo::AbstractVarInfo=TypedVarInfo(model),
context::AbstractContext=leafcontext(model.context);
adtype::Union{ADTypes.AbstractADType,Nothing}=nothing,
)
Expand Down
20 changes: 10 additions & 10 deletions src/model.jl
Original file line number Diff line number Diff line change
Expand Up @@ -433,7 +433,7 @@ julia> conditioned(cm)

julia> # Since we conditioned on `m`, not `a.m` as it will appear after prefixed,
# `a.m` is treated as a random variable.
keys(VarInfo(cm))
keys(TypedVarInfo(cm))
1-element Vector{VarName{:a, Accessors.PropertyLens{:m}}}:
a.m

Expand All @@ -446,7 +446,7 @@ julia> conditioned(cm)[@varname(x)]
julia> conditioned(cm)[@varname(a.m)]
1.0

julia> keys(VarInfo(cm)) # No variables are sampled
julia> keys(TypedVarInfo(cm)) # No variables are sampled
VarName[]
```
"""
Expand Down Expand Up @@ -773,7 +773,7 @@ julia> fixed(cm)

julia> # Since we fixed on `m`, not `a.m` as it will appear after prefixed,
# `a.m` is treated as a random variable.
keys(VarInfo(cm))
keys(TypedVarInfo(cm))
1-element Vector{VarName{:a, Accessors.PropertyLens{:m}}}:
a.m

Expand All @@ -786,7 +786,7 @@ julia> fixed(cm)[@varname(x)]
julia> fixed(cm)[@varname(a.m)]
1.0

julia> keys(VarInfo(cm)) # <= no variables are sampled
julia> keys(TypedVarInfo(cm)) # <= no variables are sampled
VarName[]
```
"""
Expand Down Expand Up @@ -1037,7 +1037,7 @@ julia> logjoint(demo_model([1., 2.]), chain);
```
"""
function logjoint(model::Model, chain::AbstractMCMC.AbstractChains)
var_info = VarInfo(model) # extract variables info from the model
var_info = TypedVarInfo(model) # extract variables info from the model
map(Iterators.product(1:size(chain, 1), 1:size(chain, 3))) do (iteration_idx, chain_idx)
argvals_dict = OrderedDict(
vn_parent =>
Expand Down Expand Up @@ -1084,7 +1084,7 @@ julia> logprior(demo_model([1., 2.]), chain);
```
"""
function logprior(model::Model, chain::AbstractMCMC.AbstractChains)
var_info = VarInfo(model) # extract variables info from the model
var_info = TypedVarInfo(model) # extract variables info from the model
map(Iterators.product(1:size(chain, 1), 1:size(chain, 3))) do (iteration_idx, chain_idx)
argvals_dict = OrderedDict(
vn_parent =>
Expand Down Expand Up @@ -1131,7 +1131,7 @@ julia> loglikelihood(demo_model([1., 2.]), chain);
```
"""
function Distributions.loglikelihood(model::Model, chain::AbstractMCMC.AbstractChains)
var_info = VarInfo(model) # extract variables info from the model
var_info = TypedVarInfo(model) # extract variables info from the model
map(Iterators.product(1:size(chain, 1), 1:size(chain, 3))) do (iteration_idx, chain_idx)
argvals_dict = OrderedDict(
vn_parent =>
Expand Down Expand Up @@ -1339,7 +1339,7 @@ julia> @model function demo2(x, y)

When we sample from the model `demo2(missing, 0.4)` random variable `x` will be sampled:
```jldoctest submodel-to_submodel
julia> vi = VarInfo(demo2(missing, 0.4));
julia> vi = TypedVarInfo(demo2(missing, 0.4));

julia> @varname(a.x) in keys(vi)
true
Expand Down Expand Up @@ -1376,7 +1376,7 @@ julia> @model function demo2_no_prefix(x, z)
return z ~ Uniform(-a, 1)
end;

julia> vi = VarInfo(demo2_no_prefix(missing, 0.4));
julia> vi = TypedVarInfo(demo2_no_prefix(missing, 0.4));

julia> @varname(x) in keys(vi) # here we just use `x` instead of `a.x`
true
Expand All @@ -1391,7 +1391,7 @@ julia> @model function demo2(x, y, z)
return z ~ Uniform(-a, b)
end;

julia> vi = VarInfo(demo2(missing, missing, 0.4));
julia> vi = TypedVarInfo(demo2(missing, missing, 0.4));

julia> @varname(sub1.x) in keys(vi)
true
Expand Down
12 changes: 6 additions & 6 deletions src/model_utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

Return `true` if all variable names in `model`/`varinfo` are in `chain`.
"""
varnames_in_chain(model::Model, chain) = varnames_in_chain(VarInfo(model), chain)
varnames_in_chain(model::Model, chain) = varnames_in_chain(TypedVarInfo(model), chain)

Check warning on line 7 in src/model_utils.jl

View check run for this annotation

Codecov / codecov/patch

src/model_utils.jl#L7

Added line #L7 was not covered by tests
function varnames_in_chain(varinfo::VarInfo, chain)
return all(vn -> varname_in_chain(varinfo, vn, chain, 1, 1), keys(varinfo))
end
Expand All @@ -16,7 +16,7 @@
Return `out` with `true` for all variable names in `model` that are in `chain`.
"""
function varnames_in_chain!(model::Model, chain, out)
return varnames_in_chain!(VarInfo(model), chain, out)
return varnames_in_chain!(TypedVarInfo(model), chain, out)

Check warning on line 19 in src/model_utils.jl

View check run for this annotation

Codecov / codecov/patch

src/model_utils.jl#L19

Added line #L19 was not covered by tests
end
function varnames_in_chain!(varinfo::VarInfo, chain, out)
for vn in keys(varinfo)
Expand All @@ -33,7 +33,7 @@
Return `true` if `vn` is in `chain` at `chain_idx` and `iteration_idx`.
"""
function varname_in_chain(model::Model, vn, chain, chain_idx, iteration_idx)
return varname_in_chain(VarInfo(model), vn, chain, chain_idx, iteration_idx)
return varname_in_chain(TypedVarInfo(model), vn, chain, chain_idx, iteration_idx)

Check warning on line 36 in src/model_utils.jl

View check run for this annotation

Codecov / codecov/patch

src/model_utils.jl#L36

Added line #L36 was not covered by tests
end

function varname_in_chain(varinfo::AbstractVarInfo, vn, chain, chain_idx, iteration_idx)
Expand All @@ -60,7 +60,7 @@
rather than a single boolean. This can be quite useful for debugging purposes.
"""
function varname_in_chain!(model::Model, vn, chain, chain_idx, iteration_idx, out)
return varname_in_chain!(VarInfo(model), vn, chain, chain_idx, iteration_idx, out)
return varname_in_chain!(TypedVarInfo(model), vn, chain, chain_idx, iteration_idx, out)

Check warning on line 63 in src/model_utils.jl

View check run for this annotation

Codecov / codecov/patch

src/model_utils.jl#L63

Added line #L63 was not covered by tests
end

function varname_in_chain!(
Expand Down Expand Up @@ -132,7 +132,7 @@
`chain` at `chain_idx` and `iteration_idx`.
"""
function values_from_chain!(model::Model, chain, chain_idx, iteration_idx, out)
return values_from_chain(VarInfo(model), chain, chain_idx, iteration_idx, out)
return values_from_chain(TypedVarInfo(model), chain, chain_idx, iteration_idx, out)

Check warning on line 135 in src/model_utils.jl

View check run for this annotation

Codecov / codecov/patch

src/model_utils.jl#L135

Added line #L135 was not covered by tests
end

function values_from_chain!(vi::AbstractVarInfo, chain, chain_idx, iteration_idx, out)
Expand Down Expand Up @@ -197,7 +197,7 @@
```
"""
function value_iterator_from_chain(model::Model, chain)
return value_iterator_from_chain(VarInfo(model), chain)
return value_iterator_from_chain(TypedVarInfo(model), chain)

Check warning on line 200 in src/model_utils.jl

View check run for this annotation

Codecov / codecov/patch

src/model_utils.jl#L200

Added line #L200 was not covered by tests
end

function value_iterator_from_chain(vi::AbstractVarInfo, chain)
Expand Down
Loading
Loading