Skip to content

Add adtype to DynamicPPL.Model #818

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 6 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 20 additions & 1 deletion HISTORY.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,23 @@
# DynamicPPL Changelog

## 0.36.0

### Models now store AD backend types

In `DynamicPPL.Model`, an extra field `adtype::Union{Nothing,ADTypes.AbstractADType}` has been added.
This field is used to store the AD backend which should be used when calculating gradients of the log density.

The field can be set by passing an extra argument to the `Model` constructor, but more realistically, it is likely that you will want to manually set the `adtype` field on an existing model using `Model(::Model, ::AbstractADType)`:

```julia
@model f() = ...
model = f()
model_with_adtype = Model(model, AutoForwardDiff())
```

As far as `DynamicPPL.Model` is concerned, this field does not actually have any effect.
However, when a `LogDensityFunction` is constructed from said model, it will inherit the `adtype` field from the model.

## 0.35.5

Several internal methods have been removed:
Expand Down Expand Up @@ -174,7 +192,8 @@ Instead of constructing a `LogDensityProblemAD.ADgradient` object, we now direct
Note that if you wish, you can still construct an `ADgradient` out of a `LogDensityFunction` object (there is nothing preventing this).

However, in this version, `LogDensityFunction` now takes an extra AD type argument.
If this argument is not provided, the behaviour is exactly the same as before, i.e. you can calculate `logdensity` but not its gradient.
By default, this AD type is inherited from the model that the `LogDensityFunction` is constructed from.
If the model does not have an AD type, or if the argument is explicitly set to `nothing`, the behaviour is exactly the same as before, i.e. you can calculate `logdensity` but not its gradient.
However, if you do pass an AD type, that will allow you to calculate the gradient as well.
You may thus find that it is easier to instead do this:

Expand Down
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "DynamicPPL"
uuid = "366bfd00-2699-11ea-058f-f148b4cae6d8"
version = "0.35.5"
version = "0.36.0"

[deps]
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
Expand Down
88 changes: 50 additions & 38 deletions src/logdensityfunction.jl
Original file line number Diff line number Diff line change
Expand Up @@ -19,29 +19,24 @@ is_supported(::ADTypes.AutoReverseDiff) = true
model::Model,
varinfo::AbstractVarInfo=VarInfo(model),
context::AbstractContext=DefaultContext();
adtype::Union{ADTypes.AbstractADType,Nothing}=nothing
adtype::Union{Nothing,ADTypes.AbstractADType}=model.adtype,
)

A struct which contains a model, along with all the information necessary to:
A struct which contains a model, along with all the information necessary to
calculate its log density at a given point.

- calculate its log density at a given point;
- and if `adtype` is provided, calculate the gradient of the log density at
that point.
If the `adtype` keyword argument is specified, it is used to overwrite the
existing `adtype` in the model supplied.

At its most basic level, a LogDensityFunction wraps the model together with its
the type of varinfo to be used, as well as the evaluation context. These must
be known in order to calculate the log density (using
[`DynamicPPL.evaluate!!`](@ref)).

If the `adtype` keyword argument is provided, then this struct will also store
the adtype along with other information for efficient calculation of the
gradient of the log density. Note that preparing a `LogDensityFunction` with an
AD type `AutoBackend()` requires the AD backend itself to have been loaded
(e.g. with `import Backend`).

`DynamicPPL.LogDensityFunction` implements the LogDensityProblems.jl interface.
If `adtype` is nothing, then only `logdensity` is implemented. If `adtype` is a
concrete AD backend type, then `logdensity_and_gradient` is also implemented.
Using this information, `DynamicPPL.LogDensityFunction` implements the
LogDensityProblems.jl interface. If the underlying model's `adtype` is nothing,
then only `logdensity` is implemented. If the model's `adtype` is a concrete AD
backend type, then `logdensity_and_gradient` is also implemented.

# Fields
$(FIELDS)
Expand Down Expand Up @@ -84,40 +79,46 @@ julia> # This also respects the context in `model`.
julia> LogDensityProblems.logdensity(f_prior, [0.0]) == logpdf(Normal(), 0.0)
true

julia> # If we also need to calculate the gradient, we can specify an AD backend.
julia> # If we also need to calculate the gradient, an AD backend must be specified as part of the model.
import ForwardDiff, ADTypes

julia> f = LogDensityFunction(model, adtype=ADTypes.AutoForwardDiff());
julia> model_with_ad = Model(model, ADTypes.AutoForwardDiff());

julia> f = LogDensityFunction(model_with_ad);

julia> LogDensityProblems.logdensity_and_gradient(f, [0.0])
(-2.3378770664093453, [1.0])

julia> # Alternatively, we can set the AD backend when creating the LogDensityFunction.
f = LogDensityFunction(model, adtype=ADTypes.AutoForwardDiff());

julia> LogDensityProblems.logdensity_and_gradient(f, [0.0])
(-2.3378770664093453, [1.0])
```
"""
struct LogDensityFunction{
M<:Model,V<:AbstractVarInfo,C<:AbstractContext,AD<:Union{Nothing,ADTypes.AbstractADType}
}
struct LogDensityFunction{M<:Model,V<:AbstractVarInfo,C<:AbstractContext}
"model used for evaluation"
model::M
"varinfo used for evaluation"
varinfo::V
"context used for evaluation; if `nothing`, `leafcontext(model.context)` will be used when applicable"
context::C
"AD type used for evaluation of log density gradient. If `nothing`, no gradient can be calculated"
adtype::AD
"(internal use only) gradient preparation object for the model"
prep::Union{Nothing,DI.GradientPrep}

function LogDensityFunction(
model::Model,
varinfo::AbstractVarInfo=VarInfo(model),
context::AbstractContext=leafcontext(model.context);
adtype::Union{ADTypes.AbstractADType,Nothing}=nothing,
adtype::Union{Nothing,ADTypes.AbstractADType}=model.adtype,
)
if adtype === nothing
prep = nothing
else
# Make backend-specific tweaks to the adtype
adtype = tweak_adtype(adtype, model, varinfo, context)
if adtype != model.adtype
model = Model(model, adtype)
end
# Check whether it is supported
is_supported(adtype) ||
@warn "The AD backend $adtype is not officially supported by DynamicPPL. Gradient calculations may still work, but compatibility is not guaranteed."
Expand All @@ -138,8 +139,8 @@ struct LogDensityFunction{
)
end
end
return new{typeof(model),typeof(varinfo),typeof(context),typeof(adtype)}(
model, varinfo, context, adtype, prep
return new{typeof(model),typeof(varinfo),typeof(context)}(
model, varinfo, context, prep
)
end
end
Expand All @@ -157,7 +158,7 @@ Create a new LogDensityFunction using the model, varinfo, and context from the g
function LogDensityFunction(
f::LogDensityFunction, adtype::Union{Nothing,ADTypes.AbstractADType}
)
return if adtype === f.adtype
return if adtype === f.model.adtype
f # Avoid recomputing prep if not needed
else
LogDensityFunction(f.model, f.varinfo, f.context; adtype=adtype)
Expand Down Expand Up @@ -187,35 +188,46 @@ end
### LogDensityProblems interface

function LogDensityProblems.capabilities(
::Type{<:LogDensityFunction{M,V,C,Nothing}}
) where {M,V,C}
::Type{
<:LogDensityFunction{
Model{F,argnames,defaultnames,missings,Targs,Tdefaults,Ctx,Nothing},V,C
},
},
) where {F,argnames,defaultnames,missings,Targs,Tdefaults,Ctx,V,C}
return LogDensityProblems.LogDensityOrder{0}()
end
function LogDensityProblems.capabilities(
::Type{<:LogDensityFunction{M,V,C,AD}}
) where {M,V,C,AD<:ADTypes.AbstractADType}
::Type{
<:LogDensityFunction{
Model{F,argnames,defaultnames,missings,Targs,Tdefaults,Ctx,TAD},V,C
},
},
) where {
F,argnames,defaultnames,missings,Targs,Tdefaults,Ctx,V,C,TAD<:ADTypes.AbstractADType
}
return LogDensityProblems.LogDensityOrder{1}()
end
function LogDensityProblems.logdensity(f::LogDensityFunction, x::AbstractVector)
return logdensity_at(x, f.model, f.varinfo, f.context)
end
function LogDensityProblems.logdensity_and_gradient(
f::LogDensityFunction{M,V,C,AD}, x::AbstractVector
) where {M,V,C,AD<:ADTypes.AbstractADType}
f.prep === nothing &&
error("Gradient preparation not available; this should not happen")
f::LogDensityFunction{M,V,C}, x::AbstractVector
) where {M,V,C}
f.prep === nothing && error(
"Attempted to call logdensity_and_gradient on a LogDensityFunction without an AD backend. You need to set an AD backend in the model before calculating the gradient of logp.",
)
x = map(identity, x) # Concretise type
# Make branching statically inferrable, i.e. type-stable (even if the two
# branches happen to return different types)
return if use_closure(f.adtype)
return if use_closure(f.model.adtype)
DI.value_and_gradient(
x -> logdensity_at(x, f.model, f.varinfo, f.context), f.prep, f.adtype, x
x -> logdensity_at(x, f.model, f.varinfo, f.context), f.prep, f.model.adtype, x
)
else
DI.value_and_gradient(
logdensity_at,
f.prep,
f.adtype,
f.model.adtype,
x,
DI.Constant(f.model),
DI.Constant(f.varinfo),
Expand Down Expand Up @@ -292,7 +304,7 @@ getmodel(f::DynamicPPL.LogDensityFunction) = f.model
Set the `DynamicPPL.Model` in the given log-density function `f` to `model`.
"""
function setmodel(f::DynamicPPL.LogDensityFunction, model::DynamicPPL.Model)
return LogDensityFunction(model, f.varinfo, f.context; adtype=f.adtype)
return LogDensityFunction(model, f.varinfo, f.context)
end

"""
Expand Down
62 changes: 48 additions & 14 deletions src/model.jl
Original file line number Diff line number Diff line change
@@ -1,17 +1,23 @@
"""
struct Model{F,argnames,defaultnames,missings,Targs,Tdefaults,Ctx<:AbstractContext}
struct Model{F,argnames,defaultnames,missings,Targs,Tdefaults,Ctx<:AbstractContext,TAD<:Union{Nothing,ADTypes.AbstractADType}}
f::F
args::NamedTuple{argnames,Targs}
defaults::NamedTuple{defaultnames,Tdefaults}
context::Ctx=DefaultContext()
adtype::TAD=nothing
end

A `Model` struct with model evaluation function of type `F`, arguments of names `argnames`
types `Targs`, default arguments of names `defaultnames` with types `Tdefaults`, missing
arguments `missings`, and evaluation context of type `Ctx`.
A `Model` struct contains the following fields:
- `f`, a model evaluation function of type `F`
- `args`, arguments of names `argnames` with types `Targs`
- `defaults`, default arguments of names `defaultnames` with types `Tdefaults`
- `context`, an evaluation context of type `Ctx`
- `adtype`, which can be nothing, or an automatic differentiation backend of type `TAD`

Its missing arguments are also stored as a type parameter `missings`.
Here `argnames`, `defaultargnames`, and `missings` are tuples of symbols, e.g. `(:a, :b)`.
`context` is by default `DefaultContext()`.

`context` is by default `DefaultContext()`, and `adtype` is by default `nothing`.

An argument with a type of `Missing` will be in `missings` by default. However, in
non-traditional use-cases `missings` can be defined differently. All variables in `missings`
Expand All @@ -33,12 +39,21 @@ julia> Model{(:y,)}(f, (x = 1.0, y = 2.0), (x = 42,)) # with special definition
Model{typeof(f),(:x, :y),(:x,),(:y,),Tuple{Float64,Float64},Tuple{Int64}}(f, (x = 1.0, y = 2.0), (x = 42,))
```
"""
struct Model{F,argnames,defaultnames,missings,Targs,Tdefaults,Ctx<:AbstractContext} <:
AbstractProbabilisticProgram
struct Model{
F,
argnames,
defaultnames,
missings,
Targs,
Tdefaults,
Ctx<:AbstractContext,
TAD<:Union{Nothing,ADTypes.AbstractADType},
} <: AbstractProbabilisticProgram
f::F
args::NamedTuple{argnames,Targs}
defaults::NamedTuple{defaultnames,Tdefaults}
context::Ctx
adtype::TAD

@doc """
Model{missings}(f, args::NamedTuple, defaults::NamedTuple)
Expand All @@ -51,9 +66,10 @@ struct Model{F,argnames,defaultnames,missings,Targs,Tdefaults,Ctx<:AbstractConte
args::NamedTuple{argnames,Targs},
defaults::NamedTuple{defaultnames,Tdefaults},
context::Ctx=DefaultContext(),
) where {missings,F,argnames,Targs,defaultnames,Tdefaults,Ctx}
return new{F,argnames,defaultnames,missings,Targs,Tdefaults,Ctx}(
f, args, defaults, context
adtype::TAD=nothing,
) where {missings,F,argnames,Targs,defaultnames,Tdefaults,Ctx,TAD}
return new{F,argnames,defaultnames,missings,Targs,Tdefaults,Ctx,TAD}(
f, args, defaults, context, adtype
)
end
end
Expand All @@ -71,22 +87,40 @@ model with different arguments.
args::NamedTuple{argnames,Targs},
defaults::NamedTuple{kwargnames,Tkwargs},
context::AbstractContext=DefaultContext(),
) where {F,argnames,Targs,kwargnames,Tkwargs}
adtype::TAD=nothing,
) where {F,argnames,Targs,kwargnames,Tkwargs,TAD}
missing_args = Tuple(
name for (name, typ) in zip(argnames, Targs.types) if typ <: Missing
)
missing_kwargs = Tuple(
name for (name, typ) in zip(kwargnames, Tkwargs.types) if typ <: Missing
)
return :(Model{$(missing_args..., missing_kwargs...)}(f, args, defaults, context))
return :(Model{$(missing_args..., missing_kwargs...)}(
f, args, defaults, context, adtype
))
end

function Model(f, args::NamedTuple, context::AbstractContext=DefaultContext(); kwargs...)
return Model(f, args, NamedTuple(kwargs), context)
return Model(f, args, NamedTuple(kwargs), context, nothing)
end

"""
Model(model::Model, adtype::Union{Nothing,ADTypes.AbstractADType})

Create a new model with the same evaluation function and arguments as `model`, but with
automatic differentiation backend `adtype`.
"""
function Model(model::Model, adtype::Union{Nothing,ADTypes.AbstractADType})
return Model(model.f, model.args, model.defaults, model.context, adtype)
end

"""
contextualize(model::Model, context::AbstractContext)

Set the context of `model` to `context`.
"""
function contextualize(model::Model, context::AbstractContext)
return Model(model.f, model.args, model.defaults, context)
return Model(model.f, model.args, model.defaults, context, model.adtype)
end

"""
Expand Down
23 changes: 19 additions & 4 deletions test/logdensityfunction.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,8 @@ using Test, DynamicPPL, ADTypes, LogDensityProblems, ForwardDiff
end

@testset "LogDensityFunction" begin
@testset "$(nameof(model))" for model in DynamicPPL.TestUtils.DEMO_MODELS
@testset "construction from $(nameof(model))" for model in
DynamicPPL.TestUtils.DEMO_MODELS
example_values = DynamicPPL.TestUtils.rand_prior_true(model)
vns = DynamicPPL.TestUtils.varnames(model)
varinfos = DynamicPPL.TestUtils.setup_varinfos(model, example_values, vns)
Expand All @@ -23,14 +24,28 @@ end
end
end

@testset "capabilities" begin
model = DynamicPPL.TestUtils.DEMO_MODELS[1]
@testset "LogDensityProblems interface" begin
@model demo_simple() = x ~ Normal()
model = demo_simple()

ldf = DynamicPPL.LogDensityFunction(model)
@test LogDensityProblems.capabilities(typeof(ldf)) ==
LogDensityProblems.LogDensityOrder{0}()
@test LogDensityProblems.logdensity(ldf, [1.0]) isa Any

ldf_with_ad = DynamicPPL.LogDensityFunction(model; adtype=AutoForwardDiff())
# Set AD type on model, then reconstruct LDF
model_with_ad = Model(model, AutoForwardDiff())
ldf_with_ad = DynamicPPL.LogDensityFunction(model_with_ad)
@test LogDensityProblems.capabilities(typeof(ldf_with_ad)) ==
LogDensityProblems.LogDensityOrder{1}()
@test LogDensityProblems.logdensity(ldf_with_ad, [1.0]) isa Any
@test LogDensityProblems.logdensity_and_gradient(ldf_with_ad, [1.0]) isa Any

# Set AD type on LDF directly
ldf_with_ad2 = DynamicPPL.LogDensityFunction(ldf, AutoForwardDiff())
@test LogDensityProblems.capabilities(typeof(ldf_with_ad2)) ==
LogDensityProblems.LogDensityOrder{1}()
@test LogDensityProblems.logdensity(ldf_with_ad2, [1.0]) isa Any
@test LogDensityProblems.logdensity_and_gradient(ldf_with_ad2, [1.0]) isa Any
end
end
10 changes: 10 additions & 0 deletions test/model.jl
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,16 @@ const GDEMO_DEFAULT = DynamicPPL.TestUtils.demo_assume_observe_literal()
end
end

@testset "model adtype" begin
# Check that adtype can be set and unset
@model demo_adtype() = x ~ Normal()
adtype = AutoForwardDiff()
model = Model(demo_adtype(), adtype)
@test model.adtype == adtype
model = Model(model, nothing)
@test model.adtype === nothing
end

@testset "model de/conditioning" begin
@model function demo_condition()
x ~ Normal()
Expand Down
Loading