diff --git a/HISTORY.md b/HISTORY.md index 3ea8071f3..eb6b19dd9 100644 --- a/HISTORY.md +++ b/HISTORY.md @@ -1,5 +1,23 @@ # DynamicPPL Changelog +## 0.36.0 + +### Models now store AD backend types + +In `DynamicPPL.Model`, an extra field `adtype::Union{Nothing,ADTypes.AbstractADType}` has been added. +This field is used to store the AD backend which should be used when calculating gradients of the log density. + +The field can be set by passing an extra argument to the `Model` constructor, but more realistically, it is likely that you will want to manually set the `adtype` field on an existing model using `Model(::Model, ::AbstractADType)`: + +```julia +@model f() = ... +model = f() +model_with_adtype = Model(model, AutoForwardDiff()) +``` + +As far as `DynamicPPL.Model` is concerned, this field does not actually have any effect. +However, when a `LogDensityFunction` is constructed from said model, it will inherit the `adtype` field from the model. + ## 0.35.5 Several internal methods have been removed: @@ -174,7 +192,8 @@ Instead of constructing a `LogDensityProblemAD.ADgradient` object, we now direct Note that if you wish, you can still construct an `ADgradient` out of a `LogDensityFunction` object (there is nothing preventing this). However, in this version, `LogDensityFunction` now takes an extra AD type argument. -If this argument is not provided, the behaviour is exactly the same as before, i.e. you can calculate `logdensity` but not its gradient. +By default, this AD type is inherited from the model that the `LogDensityFunction` is constructed from. +If the model does not have an AD type, or if the argument is explicitly set to `nothing`, the behaviour is exactly the same as before, i.e. you can calculate `logdensity` but not its gradient. However, if you do pass an AD type, that will allow you to calculate the gradient as well. You may thus find that it is easier to instead do this: diff --git a/Project.toml b/Project.toml index 05d33ec36..d5185d727 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "DynamicPPL" uuid = "366bfd00-2699-11ea-058f-f148b4cae6d8" -version = "0.35.5" +version = "0.36.0" [deps] ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" diff --git a/src/logdensityfunction.jl b/src/logdensityfunction.jl index a42855f05..93223194a 100644 --- a/src/logdensityfunction.jl +++ b/src/logdensityfunction.jl @@ -19,29 +19,24 @@ is_supported(::ADTypes.AutoReverseDiff) = true model::Model, varinfo::AbstractVarInfo=VarInfo(model), context::AbstractContext=DefaultContext(); - adtype::Union{ADTypes.AbstractADType,Nothing}=nothing + adtype::Union{Nothing,ADTypes.AbstractADType}=model.adtype, ) -A struct which contains a model, along with all the information necessary to: +A struct which contains a model, along with all the information necessary to +calculate its log density at a given point. - - calculate its log density at a given point; - - and if `adtype` is provided, calculate the gradient of the log density at - that point. +If the `adtype` keyword argument is specified, it is used to overwrite the +existing `adtype` in the model supplied. At its most basic level, a LogDensityFunction wraps the model together with its the type of varinfo to be used, as well as the evaluation context. These must be known in order to calculate the log density (using [`DynamicPPL.evaluate!!`](@ref)). -If the `adtype` keyword argument is provided, then this struct will also store -the adtype along with other information for efficient calculation of the -gradient of the log density. Note that preparing a `LogDensityFunction` with an -AD type `AutoBackend()` requires the AD backend itself to have been loaded -(e.g. with `import Backend`). - -`DynamicPPL.LogDensityFunction` implements the LogDensityProblems.jl interface. -If `adtype` is nothing, then only `logdensity` is implemented. If `adtype` is a -concrete AD backend type, then `logdensity_and_gradient` is also implemented. +Using this information, `DynamicPPL.LogDensityFunction` implements the +LogDensityProblems.jl interface. If the underlying model's `adtype` is nothing, +then only `logdensity` is implemented. If the model's `adtype` is a concrete AD +backend type, then `logdensity_and_gradient` is also implemented. # Fields $(FIELDS) @@ -84,26 +79,30 @@ julia> # This also respects the context in `model`. julia> LogDensityProblems.logdensity(f_prior, [0.0]) == logpdf(Normal(), 0.0) true -julia> # If we also need to calculate the gradient, we can specify an AD backend. +julia> # If we also need to calculate the gradient, an AD backend must be specified as part of the model. import ForwardDiff, ADTypes -julia> f = LogDensityFunction(model, adtype=ADTypes.AutoForwardDiff()); +julia> model_with_ad = Model(model, ADTypes.AutoForwardDiff()); + +julia> f = LogDensityFunction(model_with_ad); + +julia> LogDensityProblems.logdensity_and_gradient(f, [0.0]) +(-2.3378770664093453, [1.0]) + +julia> # Alternatively, we can set the AD backend when creating the LogDensityFunction. + f = LogDensityFunction(model, adtype=ADTypes.AutoForwardDiff()); julia> LogDensityProblems.logdensity_and_gradient(f, [0.0]) (-2.3378770664093453, [1.0]) ``` """ -struct LogDensityFunction{ - M<:Model,V<:AbstractVarInfo,C<:AbstractContext,AD<:Union{Nothing,ADTypes.AbstractADType} -} +struct LogDensityFunction{M<:Model,V<:AbstractVarInfo,C<:AbstractContext} "model used for evaluation" model::M "varinfo used for evaluation" varinfo::V "context used for evaluation; if `nothing`, `leafcontext(model.context)` will be used when applicable" context::C - "AD type used for evaluation of log density gradient. If `nothing`, no gradient can be calculated" - adtype::AD "(internal use only) gradient preparation object for the model" prep::Union{Nothing,DI.GradientPrep} @@ -111,13 +110,15 @@ struct LogDensityFunction{ model::Model, varinfo::AbstractVarInfo=VarInfo(model), context::AbstractContext=leafcontext(model.context); - adtype::Union{ADTypes.AbstractADType,Nothing}=nothing, + adtype::Union{Nothing,ADTypes.AbstractADType}=model.adtype, ) if adtype === nothing prep = nothing else - # Make backend-specific tweaks to the adtype adtype = tweak_adtype(adtype, model, varinfo, context) + if adtype != model.adtype + model = Model(model, adtype) + end # Check whether it is supported is_supported(adtype) || @warn "The AD backend $adtype is not officially supported by DynamicPPL. Gradient calculations may still work, but compatibility is not guaranteed." @@ -138,8 +139,8 @@ struct LogDensityFunction{ ) end end - return new{typeof(model),typeof(varinfo),typeof(context),typeof(adtype)}( - model, varinfo, context, adtype, prep + return new{typeof(model),typeof(varinfo),typeof(context)}( + model, varinfo, context, prep ) end end @@ -157,7 +158,7 @@ Create a new LogDensityFunction using the model, varinfo, and context from the g function LogDensityFunction( f::LogDensityFunction, adtype::Union{Nothing,ADTypes.AbstractADType} ) - return if adtype === f.adtype + return if adtype === f.model.adtype f # Avoid recomputing prep if not needed else LogDensityFunction(f.model, f.varinfo, f.context; adtype=adtype) @@ -187,35 +188,46 @@ end ### LogDensityProblems interface function LogDensityProblems.capabilities( - ::Type{<:LogDensityFunction{M,V,C,Nothing}} -) where {M,V,C} + ::Type{ + <:LogDensityFunction{ + Model{F,argnames,defaultnames,missings,Targs,Tdefaults,Ctx,Nothing},V,C + }, + }, +) where {F,argnames,defaultnames,missings,Targs,Tdefaults,Ctx,V,C} return LogDensityProblems.LogDensityOrder{0}() end function LogDensityProblems.capabilities( - ::Type{<:LogDensityFunction{M,V,C,AD}} -) where {M,V,C,AD<:ADTypes.AbstractADType} + ::Type{ + <:LogDensityFunction{ + Model{F,argnames,defaultnames,missings,Targs,Tdefaults,Ctx,TAD},V,C + }, + }, +) where { + F,argnames,defaultnames,missings,Targs,Tdefaults,Ctx,V,C,TAD<:ADTypes.AbstractADType +} return LogDensityProblems.LogDensityOrder{1}() end function LogDensityProblems.logdensity(f::LogDensityFunction, x::AbstractVector) return logdensity_at(x, f.model, f.varinfo, f.context) end function LogDensityProblems.logdensity_and_gradient( - f::LogDensityFunction{M,V,C,AD}, x::AbstractVector -) where {M,V,C,AD<:ADTypes.AbstractADType} - f.prep === nothing && - error("Gradient preparation not available; this should not happen") + f::LogDensityFunction{M,V,C}, x::AbstractVector +) where {M,V,C} + f.prep === nothing && error( + "Attempted to call logdensity_and_gradient on a LogDensityFunction without an AD backend. You need to set an AD backend in the model before calculating the gradient of logp.", + ) x = map(identity, x) # Concretise type # Make branching statically inferrable, i.e. type-stable (even if the two # branches happen to return different types) - return if use_closure(f.adtype) + return if use_closure(f.model.adtype) DI.value_and_gradient( - x -> logdensity_at(x, f.model, f.varinfo, f.context), f.prep, f.adtype, x + x -> logdensity_at(x, f.model, f.varinfo, f.context), f.prep, f.model.adtype, x ) else DI.value_and_gradient( logdensity_at, f.prep, - f.adtype, + f.model.adtype, x, DI.Constant(f.model), DI.Constant(f.varinfo), @@ -292,7 +304,7 @@ getmodel(f::DynamicPPL.LogDensityFunction) = f.model Set the `DynamicPPL.Model` in the given log-density function `f` to `model`. """ function setmodel(f::DynamicPPL.LogDensityFunction, model::DynamicPPL.Model) - return LogDensityFunction(model, f.varinfo, f.context; adtype=f.adtype) + return LogDensityFunction(model, f.varinfo, f.context) end """ diff --git a/src/model.jl b/src/model.jl index a0451b1b6..8f28a552d 100644 --- a/src/model.jl +++ b/src/model.jl @@ -1,17 +1,23 @@ """ - struct Model{F,argnames,defaultnames,missings,Targs,Tdefaults,Ctx<:AbstractContext} +struct Model{F,argnames,defaultnames,missings,Targs,Tdefaults,Ctx<:AbstractContext,TAD<:Union{Nothing,ADTypes.AbstractADType}} f::F args::NamedTuple{argnames,Targs} defaults::NamedTuple{defaultnames,Tdefaults} context::Ctx=DefaultContext() + adtype::TAD=nothing end -A `Model` struct with model evaluation function of type `F`, arguments of names `argnames` -types `Targs`, default arguments of names `defaultnames` with types `Tdefaults`, missing -arguments `missings`, and evaluation context of type `Ctx`. +A `Model` struct contains the following fields: +- `f`, a model evaluation function of type `F` +- `args`, arguments of names `argnames` with types `Targs` +- `defaults`, default arguments of names `defaultnames` with types `Tdefaults` +- `context`, an evaluation context of type `Ctx` +- `adtype`, which can be nothing, or an automatic differentiation backend of type `TAD` +Its missing arguments are also stored as a type parameter `missings`. Here `argnames`, `defaultargnames`, and `missings` are tuples of symbols, e.g. `(:a, :b)`. -`context` is by default `DefaultContext()`. + +`context` is by default `DefaultContext()`, and `adtype` is by default `nothing`. An argument with a type of `Missing` will be in `missings` by default. However, in non-traditional use-cases `missings` can be defined differently. All variables in `missings` @@ -33,12 +39,21 @@ julia> Model{(:y,)}(f, (x = 1.0, y = 2.0), (x = 42,)) # with special definition Model{typeof(f),(:x, :y),(:x,),(:y,),Tuple{Float64,Float64},Tuple{Int64}}(f, (x = 1.0, y = 2.0), (x = 42,)) ``` """ -struct Model{F,argnames,defaultnames,missings,Targs,Tdefaults,Ctx<:AbstractContext} <: - AbstractProbabilisticProgram +struct Model{ + F, + argnames, + defaultnames, + missings, + Targs, + Tdefaults, + Ctx<:AbstractContext, + TAD<:Union{Nothing,ADTypes.AbstractADType}, +} <: AbstractProbabilisticProgram f::F args::NamedTuple{argnames,Targs} defaults::NamedTuple{defaultnames,Tdefaults} context::Ctx + adtype::TAD @doc """ Model{missings}(f, args::NamedTuple, defaults::NamedTuple) @@ -51,9 +66,10 @@ struct Model{F,argnames,defaultnames,missings,Targs,Tdefaults,Ctx<:AbstractConte args::NamedTuple{argnames,Targs}, defaults::NamedTuple{defaultnames,Tdefaults}, context::Ctx=DefaultContext(), - ) where {missings,F,argnames,Targs,defaultnames,Tdefaults,Ctx} - return new{F,argnames,defaultnames,missings,Targs,Tdefaults,Ctx}( - f, args, defaults, context + adtype::TAD=nothing, + ) where {missings,F,argnames,Targs,defaultnames,Tdefaults,Ctx,TAD} + return new{F,argnames,defaultnames,missings,Targs,Tdefaults,Ctx,TAD}( + f, args, defaults, context, adtype ) end end @@ -71,22 +87,40 @@ model with different arguments. args::NamedTuple{argnames,Targs}, defaults::NamedTuple{kwargnames,Tkwargs}, context::AbstractContext=DefaultContext(), -) where {F,argnames,Targs,kwargnames,Tkwargs} + adtype::TAD=nothing, +) where {F,argnames,Targs,kwargnames,Tkwargs,TAD} missing_args = Tuple( name for (name, typ) in zip(argnames, Targs.types) if typ <: Missing ) missing_kwargs = Tuple( name for (name, typ) in zip(kwargnames, Tkwargs.types) if typ <: Missing ) - return :(Model{$(missing_args..., missing_kwargs...)}(f, args, defaults, context)) + return :(Model{$(missing_args..., missing_kwargs...)}( + f, args, defaults, context, adtype + )) end function Model(f, args::NamedTuple, context::AbstractContext=DefaultContext(); kwargs...) - return Model(f, args, NamedTuple(kwargs), context) + return Model(f, args, NamedTuple(kwargs), context, nothing) end +""" + Model(model::Model, adtype::Union{Nothing,ADTypes.AbstractADType}) + +Create a new model with the same evaluation function and arguments as `model`, but with +automatic differentiation backend `adtype`. +""" +function Model(model::Model, adtype::Union{Nothing,ADTypes.AbstractADType}) + return Model(model.f, model.args, model.defaults, model.context, adtype) +end + +""" + contextualize(model::Model, context::AbstractContext) + +Set the context of `model` to `context`. +""" function contextualize(model::Model, context::AbstractContext) - return Model(model.f, model.args, model.defaults, context) + return Model(model.f, model.args, model.defaults, context, model.adtype) end """ diff --git a/test/logdensityfunction.jl b/test/logdensityfunction.jl index d6e66ec59..114b61a90 100644 --- a/test/logdensityfunction.jl +++ b/test/logdensityfunction.jl @@ -10,7 +10,8 @@ using Test, DynamicPPL, ADTypes, LogDensityProblems, ForwardDiff end @testset "LogDensityFunction" begin - @testset "$(nameof(model))" for model in DynamicPPL.TestUtils.DEMO_MODELS + @testset "construction from $(nameof(model))" for model in + DynamicPPL.TestUtils.DEMO_MODELS example_values = DynamicPPL.TestUtils.rand_prior_true(model) vns = DynamicPPL.TestUtils.varnames(model) varinfos = DynamicPPL.TestUtils.setup_varinfos(model, example_values, vns) @@ -23,14 +24,28 @@ end end end - @testset "capabilities" begin - model = DynamicPPL.TestUtils.DEMO_MODELS[1] + @testset "LogDensityProblems interface" begin + @model demo_simple() = x ~ Normal() + model = demo_simple() + ldf = DynamicPPL.LogDensityFunction(model) @test LogDensityProblems.capabilities(typeof(ldf)) == LogDensityProblems.LogDensityOrder{0}() + @test LogDensityProblems.logdensity(ldf, [1.0]) isa Any - ldf_with_ad = DynamicPPL.LogDensityFunction(model; adtype=AutoForwardDiff()) + # Set AD type on model, then reconstruct LDF + model_with_ad = Model(model, AutoForwardDiff()) + ldf_with_ad = DynamicPPL.LogDensityFunction(model_with_ad) @test LogDensityProblems.capabilities(typeof(ldf_with_ad)) == LogDensityProblems.LogDensityOrder{1}() + @test LogDensityProblems.logdensity(ldf_with_ad, [1.0]) isa Any + @test LogDensityProblems.logdensity_and_gradient(ldf_with_ad, [1.0]) isa Any + + # Set AD type on LDF directly + ldf_with_ad2 = DynamicPPL.LogDensityFunction(ldf, AutoForwardDiff()) + @test LogDensityProblems.capabilities(typeof(ldf_with_ad2)) == + LogDensityProblems.LogDensityOrder{1}() + @test LogDensityProblems.logdensity(ldf_with_ad2, [1.0]) isa Any + @test LogDensityProblems.logdensity_and_gradient(ldf_with_ad2, [1.0]) isa Any end end diff --git a/test/model.jl b/test/model.jl index a863b6596..8094afcc5 100644 --- a/test/model.jl +++ b/test/model.jl @@ -100,6 +100,16 @@ const GDEMO_DEFAULT = DynamicPPL.TestUtils.demo_assume_observe_literal() end end + @testset "model adtype" begin + # Check that adtype can be set and unset + @model demo_adtype() = x ~ Normal() + adtype = AutoForwardDiff() + model = Model(demo_adtype(), adtype) + @test model.adtype == adtype + model = Model(model, nothing) + @test model.adtype === nothing + end + @testset "model de/conditioning" begin @model function demo_condition() x ~ Normal()