Skip to content

Commit dc95f97

Browse files
committed
Re-add the LogDensityFunction(...; adtype) method
1 parent d0dba32 commit dc95f97

File tree

1 file changed

+22
-9
lines changed

1 file changed

+22
-9
lines changed

src/logdensityfunction.jl

+22-9
Original file line numberDiff line numberDiff line change
@@ -18,17 +18,26 @@ is_supported(::ADTypes.AutoReverseDiff) = true
1818
LogDensityFunction(
1919
model::Model,
2020
varinfo::AbstractVarInfo=VarInfo(model),
21-
context::AbstractContext=DefaultContext()
21+
context::AbstractContext=DefaultContext();
22+
adtype::Union{Nothing,ADTypes.AbstractADType}=model.adtype,
2223
)
2324
2425
A struct which contains a model, along with all the information necessary to
2526
calculate its log density at a given point.
2627
28+
If the `adtype` keyword argument is specified, it is used to overwrite the
29+
existing `adtype` in the model supplied.
30+
2731
At its most basic level, a LogDensityFunction wraps the model together with its
2832
the type of varinfo to be used, as well as the evaluation context. These must
2933
be known in order to calculate the log density (using
3034
[`DynamicPPL.evaluate!!`](@ref)).
3135
36+
Using this information, `DynamicPPL.LogDensityFunction` implements the
37+
LogDensityProblems.jl interface. If the underlying model's `adtype` is nothing,
38+
then only `logdensity` is implemented. If the model's `adtype` is a concrete AD
39+
backend type, then `logdensity_and_gradient` is also implemented.
40+
3241
# Fields
3342
$(FIELDS)
3443
@@ -77,6 +86,12 @@ julia> model_with_ad = Model(model, ADTypes.AutoForwardDiff());
7786
7887
julia> f = LogDensityFunction(model_with_ad);
7988
89+
julia> LogDensityProblems.logdensity_and_gradient(f, [0.0])
90+
(-2.3378770664093453, [1.0])
91+
92+
julia> # Alternatively, we can set the AD backend when creating the LogDensityFunction.
93+
f = LogDensityFunction(model, adtype=ADTypes.AutoForwardDiff());
94+
8095
julia> LogDensityProblems.logdensity_and_gradient(f, [0.0])
8196
(-2.3378770664093453, [1.0])
8297
```
@@ -94,18 +109,16 @@ struct LogDensityFunction{M<:Model,V<:AbstractVarInfo,C<:AbstractContext}
94109
function LogDensityFunction(
95110
model::Model,
96111
varinfo::AbstractVarInfo=VarInfo(model),
97-
context::AbstractContext=leafcontext(model.context),
112+
context::AbstractContext=leafcontext(model.context);
113+
adtype::Union{Nothing,ADTypes.AbstractADType}=model.adtype,
98114
)
99-
adtype = model.adtype
100115
if adtype === nothing
101116
prep = nothing
102117
else
103-
# Make backend-specific tweaks to the adtype
104-
# This should arguably be done in the model constructor, but it needs the
105-
# varinfo and context to do so, and it seems excessive to construct a
106-
# varinfo at the point of calling Model().
107118
adtype = tweak_adtype(adtype, model, varinfo, context)
108-
model = Model(model, adtype)
119+
if adtype != model.adtype
120+
model = Model(model, adtype)
121+
end
109122
# Check whether it is supported
110123
is_supported(adtype) ||
111124
@warn "The AD backend $adtype is not officially supported by DynamicPPL. Gradient calculations may still work, but compatibility is not guaranteed."
@@ -148,7 +161,7 @@ function LogDensityFunction(
148161
return if adtype === f.model.adtype
149162
f # Avoid recomputing prep if not needed
150163
else
151-
LogDensityFunction(Model(f.model, adtype), f.varinfo, f.context)
164+
LogDensityFunction(f.model, f.varinfo, f.context; adtype=adtype)
152165
end
153166
end
154167

0 commit comments

Comments
 (0)