@@ -18,17 +18,26 @@ is_supported(::ADTypes.AutoReverseDiff) = true
18
18
LogDensityFunction(
19
19
model::Model,
20
20
varinfo::AbstractVarInfo=VarInfo(model),
21
- context::AbstractContext=DefaultContext()
21
+ context::AbstractContext=DefaultContext();
22
+ adtype::Union{Nothing,ADTypes.AbstractADType}=model.adtype,
22
23
)
23
24
24
25
A struct which contains a model, along with all the information necessary to
25
26
calculate its log density at a given point.
26
27
28
+ If the `adtype` keyword argument is specified, it is used to overwrite the
29
+ existing `adtype` in the model supplied.
30
+
27
31
At its most basic level, a LogDensityFunction wraps the model together with its
28
32
the type of varinfo to be used, as well as the evaluation context. These must
29
33
be known in order to calculate the log density (using
30
34
[`DynamicPPL.evaluate!!`](@ref)).
31
35
36
+ Using this information, `DynamicPPL.LogDensityFunction` implements the
37
+ LogDensityProblems.jl interface. If the underlying model's `adtype` is nothing,
38
+ then only `logdensity` is implemented. If the model's `adtype` is a concrete AD
39
+ backend type, then `logdensity_and_gradient` is also implemented.
40
+
32
41
# Fields
33
42
$(FIELDS)
34
43
@@ -77,6 +86,12 @@ julia> model_with_ad = Model(model, ADTypes.AutoForwardDiff());
77
86
78
87
julia> f = LogDensityFunction(model_with_ad);
79
88
89
+ julia> LogDensityProblems.logdensity_and_gradient(f, [0.0])
90
+ (-2.3378770664093453, [1.0])
91
+
92
+ julia> # Alternatively, we can set the AD backend when creating the LogDensityFunction.
93
+ f = LogDensityFunction(model, adtype=ADTypes.AutoForwardDiff());
94
+
80
95
julia> LogDensityProblems.logdensity_and_gradient(f, [0.0])
81
96
(-2.3378770664093453, [1.0])
82
97
```
@@ -94,18 +109,16 @@ struct LogDensityFunction{M<:Model,V<:AbstractVarInfo,C<:AbstractContext}
94
109
function LogDensityFunction (
95
110
model:: Model ,
96
111
varinfo:: AbstractVarInfo = VarInfo (model),
97
- context:: AbstractContext = leafcontext (model. context),
112
+ context:: AbstractContext = leafcontext (model. context);
113
+ adtype:: Union{Nothing,ADTypes.AbstractADType} = model. adtype,
98
114
)
99
- adtype = model. adtype
100
115
if adtype === nothing
101
116
prep = nothing
102
117
else
103
- # Make backend-specific tweaks to the adtype
104
- # This should arguably be done in the model constructor, but it needs the
105
- # varinfo and context to do so, and it seems excessive to construct a
106
- # varinfo at the point of calling Model().
107
118
adtype = tweak_adtype (adtype, model, varinfo, context)
108
- model = Model (model, adtype)
119
+ if adtype != model. adtype
120
+ model = Model (model, adtype)
121
+ end
109
122
# Check whether it is supported
110
123
is_supported (adtype) ||
111
124
@warn " The AD backend $adtype is not officially supported by DynamicPPL. Gradient calculations may still work, but compatibility is not guaranteed."
@@ -148,7 +161,7 @@ function LogDensityFunction(
148
161
return if adtype === f. model. adtype
149
162
f # Avoid recomputing prep if not needed
150
163
else
151
- LogDensityFunction (Model ( f. model, adtype), f. varinfo, f. context)
164
+ LogDensityFunction (f. model, f. varinfo, f. context; adtype = adtype )
152
165
end
153
166
end
154
167
0 commit comments