
Commit 04bcecd

Remove separate adtype field from LogDensityFunction
1 parent f96dc3c commit 04bcecd

File tree

2 files changed (+47 −41 lines)

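The practical effect of the change, sketched minimally in the new style (the `demo` model and the import list below are illustrative, not part of this diff):

using DynamicPPL, Distributions, ADTypes, ForwardDiff, LogDensityProblems

@model demo() = x ~ Normal()

# The AD backend is now attached to the Model; LogDensityFunction reads
# `model.adtype` instead of taking a separate `adtype` keyword argument.
model = Model(demo(), AutoForwardDiff())
ldf = DynamicPPL.LogDensityFunction(model)

LogDensityProblems.logdensity_and_gradient(ldf, [0.0])  # (log density, gradient)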

src/logdensityfunction.jl

+28 −21
@@ -93,31 +93,28 @@ julia> LogDensityProblems.logdensity_and_gradient(f, [0.0])
 (-2.3378770664093453, [1.0])
 ```
 """
-struct LogDensityFunction{
-    M<:Model,V<:AbstractVarInfo,C<:AbstractContext,AD<:Union{Nothing,ADTypes.AbstractADType}
-}
+struct LogDensityFunction{M<:Model,V<:AbstractVarInfo,C<:AbstractContext}
     "model used for evaluation"
     model::M
     "varinfo used for evaluation"
     varinfo::V
     "context used for evaluation; if `nothing`, `leafcontext(model.context)` will be used when applicable"
     context::C
-    "AD type used for evaluation of log density gradient. If `nothing`, no gradient can be calculated"
-    adtype::AD
     "(internal use only) gradient preparation object for the model"
     prep::Union{Nothing,DI.GradientPrep}

     function LogDensityFunction(
         model::Model,
         varinfo::AbstractVarInfo=VarInfo(model),
-        context::AbstractContext=leafcontext(model.context);
-        adtype::Union{ADTypes.AbstractADType,Nothing}=model.adtype,
+        context::AbstractContext=leafcontext(model.context),
     )
+        adtype = model.adtype
         if adtype === nothing
             prep = nothing
         else
             # Make backend-specific tweaks to the adtype
             adtype = tweak_adtype(adtype, model, varinfo, context)
+            model = Model(model, adtype)
             # Check whether it is supported
             is_supported(adtype) ||
                 @warn "The AD backend $adtype is not officially supported by DynamicPPL. Gradient calculations may still work, but compatibility is not guaranteed."
@@ -138,8 +135,8 @@ struct LogDensityFunction{
                 )
             end
         end
-        return new{typeof(model),typeof(varinfo),typeof(context),typeof(adtype)}(
-            model, varinfo, context, adtype, prep
+        return new{typeof(model),typeof(varinfo),typeof(context)}(
+            model, varinfo, context, prep
         )
     end
 end
@@ -157,10 +154,10 @@ Create a new LogDensityFunction using the model, varinfo, and context from the g
 function LogDensityFunction(
     f::LogDensityFunction, adtype::Union{Nothing,ADTypes.AbstractADType}
 )
-    return if adtype === f.adtype
+    return if adtype === f.model.adtype
         f # Avoid recomputing prep if not needed
     else
-        LogDensityFunction(f.model, f.varinfo, f.context; adtype=adtype)
+        LogDensityFunction(Model(f.model, adtype), f.varinfo, f.context)
     end
 end

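The rebuild-with-new-AD-type constructor above only does work when the AD type actually changes. A rough sketch of its behaviour, reusing the hypothetical `demo` model from the earlier sketch:

# No AD type on the model: no gradient preparation is computed.
ldf_plain = DynamicPPL.LogDensityFunction(demo())

# Passing an AD type wraps the stored model via Model(f.model, adtype)
# and recomputes the gradient preparation.
ldf_fd = DynamicPPL.LogDensityFunction(ldf_plain, AutoForwardDiff())

# Passing back the identical AD type object returns the same LDF,
# skipping a second (potentially costly) gradient preparation.
DynamicPPL.LogDensityFunction(ldf_fd, ldf_fd.model.adtype) === ldf_fd  # true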
@@ -187,35 +184,45 @@ end
 ### LogDensityProblems interface

 function LogDensityProblems.capabilities(
-    ::Type{<:LogDensityFunction{M,V,C,Nothing}}
-) where {M,V,C}
+    ::Type{
+        <:LogDensityFunction{
+            Model{F,argnames,defaultnames,missings,Targs,Tdefaults,Ctx,Nothing},V,C
+        },
+    },
+) where {F,argnames,defaultnames,missings,Targs,Tdefaults,Ctx,V,C}
     return LogDensityProblems.LogDensityOrder{0}()
 end
 function LogDensityProblems.capabilities(
-    ::Type{<:LogDensityFunction{M,V,C,AD}}
-) where {M,V,C,AD<:ADTypes.AbstractADType}
+    ::Type{
+        <:LogDensityFunction{
+            Model{F,argnames,defaultnames,missings,Targs,Tdefaults,Ctx,TAD},V,C
+        },
+    },
+) where {
+    F,argnames,defaultnames,missings,Targs,Tdefaults,Ctx,V,C,TAD<:ADTypes.AbstractADType
+}
     return LogDensityProblems.LogDensityOrder{1}()
 end
 function LogDensityProblems.logdensity(f::LogDensityFunction, x::AbstractVector)
     return logdensity_at(x, f.model, f.varinfo, f.context)
 end
 function LogDensityProblems.logdensity_and_gradient(
-    f::LogDensityFunction{M,V,C,AD}, x::AbstractVector
-) where {M,V,C,AD<:ADTypes.AbstractADType}
+    f::LogDensityFunction{M,V,C}, x::AbstractVector
+) where {M,V,C}
     f.prep === nothing &&
         error("Gradient preparation not available; this should not happen")
     x = map(identity, x) # Concretise type
     # Make branching statically inferrable, i.e. type-stable (even if the two
     # branches happen to return different types)
-    return if use_closure(f.adtype)
+    return if use_closure(f.model.adtype)
         DI.value_and_gradient(
-            x -> logdensity_at(x, f.model, f.varinfo, f.context), f.prep, f.adtype, x
+            x -> logdensity_at(x, f.model, f.varinfo, f.context), f.prep, f.model.adtype, x
         )
     else
         DI.value_and_gradient(
             logdensity_at,
             f.prep,
-            f.adtype,
+            f.model.adtype,
             x,
             DI.Constant(f.model),
             DI.Constant(f.varinfo),
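With the AD type now part of the Model's type parameters, the advertised LogDensityProblems order follows from the model alone. Roughly, again with the hypothetical `demo` model:

# Model without an AD type: zeroth order, log density only.
LogDensityProblems.capabilities(typeof(DynamicPPL.LogDensityFunction(demo())))
# LogDensityProblems.LogDensityOrder{0}()

# Model carrying an AD type: first order, gradients available.
LogDensityProblems.capabilities(
    typeof(DynamicPPL.LogDensityFunction(Model(demo(), AutoForwardDiff())))
)
# LogDensityProblems.LogDensityOrder{1}()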
@@ -292,7 +299,7 @@ getmodel(f::DynamicPPL.LogDensityFunction) = f.model
 Set the `DynamicPPL.Model` in the given log-density function `f` to `model`.
 """
 function setmodel(f::DynamicPPL.LogDensityFunction, model::DynamicPPL.Model)
-    return LogDensityFunction(model, f.varinfo, f.context; adtype=f.adtype)
+    return LogDensityFunction(model, f.varinfo, f.context)
 end

 """

test/logdensityfunction.jl

+19 −20
@@ -9,24 +9,9 @@ using Test, DynamicPPL, ADTypes, LogDensityProblems, ForwardDiff
     end
 end

-@testset "AD type forwarding from model" begin
-    @model demo_simple() = x ~ Normal()
-    model = Model(demo_simple(), AutoForwardDiff())
-    ldf = DynamicPPL.LogDensityFunction(model)
-    # Check that the model's AD type is forwarded to the LDF
-    # Note: can't check ldf.adtype == AutoForwardDiff() because `tweak_adtype`
-    # modifies the underlying parameters a bit, so just check that it is still
-    # the correct backend package.
-    @test ldf.adtype isa AutoForwardDiff
-    # Check that the gradient can be evaluated on the resulting LDF
-    @test LogDensityProblems.capabilities(typeof(ldf)) ==
-        LogDensityProblems.LogDensityOrder{1}()
-    @test LogDensityProblems.logdensity(ldf, [1.0]) isa Any
-    @test LogDensityProblems.logdensity_and_gradient(ldf, [1.0]) isa Any
-end
-
 @testset "LogDensityFunction" begin
-    @testset "$(nameof(model))" for model in DynamicPPL.TestUtils.DEMO_MODELS
+    @testset "construction from $(nameof(model))" for model in
+        DynamicPPL.TestUtils.DEMO_MODELS
         example_values = DynamicPPL.TestUtils.rand_prior_true(model)
         vns = DynamicPPL.TestUtils.varnames(model)
         varinfos = DynamicPPL.TestUtils.setup_varinfos(model, example_values, vns)
@@ -39,14 +24,28 @@ end
         end
     end

-    @testset "capabilities" begin
-        model = DynamicPPL.TestUtils.DEMO_MODELS[1]
+    @testset "LogDensityProblems interface" begin
+        @model demo_simple() = x ~ Normal()
+        model = demo_simple()
+
         ldf = DynamicPPL.LogDensityFunction(model)
         @test LogDensityProblems.capabilities(typeof(ldf)) ==
             LogDensityProblems.LogDensityOrder{0}()
+        @test LogDensityProblems.logdensity(ldf, [1.0]) isa Any

-        ldf_with_ad = DynamicPPL.LogDensityFunction(model; adtype=AutoForwardDiff())
+        # Set AD type on model, then reconstruct LDF
+        model_with_ad = Model(model, AutoForwardDiff())
+        ldf_with_ad = DynamicPPL.LogDensityFunction(model_with_ad)
         @test LogDensityProblems.capabilities(typeof(ldf_with_ad)) ==
             LogDensityProblems.LogDensityOrder{1}()
+        @test LogDensityProblems.logdensity(ldf_with_ad, [1.0]) isa Any
+        @test LogDensityProblems.logdensity_and_gradient(ldf_with_ad, [1.0]) isa Any
+
+        # Set AD type on LDF directly
+        ldf_with_ad2 = DynamicPPL.LogDensityFunction(ldf, AutoForwardDiff())
+        @test LogDensityProblems.capabilities(typeof(ldf_with_ad2)) ==
+            LogDensityProblems.LogDensityOrder{1}()
+        @test LogDensityProblems.logdensity(ldf_with_ad2, [1.0]) isa Any
+        @test LogDensityProblems.logdensity_and_gradient(ldf_with_ad2, [1.0]) isa Any
     end
 end
