
Commit dac729e

Add AD testing utilities
1 parent 1366440 commit dac729e

7 files changed: +231 -16 lines changed

HISTORY.md (+5)

@@ -49,6 +49,11 @@ This release removes the feature of `VarInfo` where it kept track of which varia
 
 This change also affects sampling in Turing.jl.
 
+### New features
+
+The `DynamicPPL.TestUtils.AD` module now contains several functions for testing the correctness of automatic differentiation of log densities.
+Please refer to the DynamicPPL documentation for more details.
+
 ## 0.34.2
 
 - Fixed bugs in ValuesAsInModelContext as well as DebugContext where underlying PrefixContexts were not being applied.
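As a quick orientation for the feature announced in this HISTORY entry, the following is a minimal, hypothetical sketch of checking a single backend with `ad_ldp`. The toy model, the conditioning value, and the choice of ForwardDiff are illustrative assumptions and are not part of this commit.

```julia
# Minimal sketch (assumed model and backend; not part of this commit).
using DynamicPPL, Distributions, ADTypes, ForwardDiff

@model function demo()
    x ~ Normal()
    y ~ Normal(x, 1)
end

model = demo() | (; y=0.5)  # condition on an (assumed) observation

# Draw parameters from the prior and evaluate value + gradient via
# LogDensityProblemsAD, the same route Turing.jl uses.
params = DynamicPPL.TestUtils.AD.make_params(model)
value, grad = DynamicPPL.TestUtils.AD.ad_ldp(model, params, AutoForwardDiff())
```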

Project.toml (+3 -5)

@@ -12,11 +12,10 @@ Bijectors = "76274a88-744f-5084-9051-94815aaf08c4"
 ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
 Compat = "34da2185-b29b-5c13-b0c7-acf172513d20"
 ConstructionBase = "187b0558-2788-49d3-abe0-74a17ed4e7c9"
+DifferentiationInterface = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63"
 Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
 DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"
 InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240"
-# TODO(penelopeysm,mhauru) KernelAbstractions is only a dependency so that we can pin its version, see
-# https://github.com/TuringLang/DynamicPPL.jl/pull/781#event-16017866767
 KernelAbstractions = "63c18a36-062a-441e-b654-da1e3ab1ce7c"
 LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
 LogDensityProblems = "6fdf6af0-433a-55f7-b3ed-c6c6e0b8df7c"
@@ -56,14 +55,13 @@ Bijectors = "0.13.18, 0.14, 0.15"
 ChainRulesCore = "1"
 Compat = "4"
 ConstructionBase = "1.5.4"
+DifferentiationInterface = "0.6.39"
 Distributions = "0.25"
 DocStringExtensions = "0.9"
-# TODO(penelopeysm,mhauru) See https://github.com/TuringLang/DynamicPPL.jl/pull/781#event-16017866767
-# for why KernelAbstractions is pinned like this.
-KernelAbstractions = "< 0.9.32"
 EnzymeCore = "0.6 - 0.8"
 ForwardDiff = "0.10"
 JET = "0.9"
+KernelAbstractions = "< 0.9.32"
 LinearAlgebra = "1.6"
 LogDensityProblems = "2"
 LogDensityProblemsAD = "1.7.0"

docs/src/api.md (+9)

@@ -219,6 +219,15 @@ DynamicPPL.TestUtils.update_values!!
 DynamicPPL.TestUtils.test_values
 ```
 
+To test whether automatic differentiation is working correctly, the following methods can be used:
+
+```@docs
+DynamicPPL.TestUtils.AD.ad_ldp
+DynamicPPL.TestUtils.AD.ad_di
+DynamicPPL.TestUtils.AD.make_function
+DynamicPPL.TestUtils.AD.make_params
+```
+
 ## Debugging Utilities
 
 DynamicPPL provides a few methods for checking validity of a model-definition.
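To illustrate the `make_function`/`make_params` route documented in the diff above: the returned closure is a plain `params -> logdensity` function, so any external gradient routine can be applied to it and compared against `ad_ldp` or `ad_di`. The sketch below is illustrative only and not part of this commit; the toy model is made up, and FiniteDifferences.jl stands in for "your AD package".

```julia
# Hypothetical sketch: differentiate the closure from make_function with an
# external gradient routine (FiniteDifferences.jl as a stand-in).
using DynamicPPL, Distributions, FiniteDifferences

@model demo_fd() = x ~ Normal()
model = demo_fd()

f = DynamicPPL.TestUtils.AD.make_function(model)       # params -> logdensity
params = DynamicPPL.TestUtils.AD.make_params(model)     # sampled from the prior

value = f(params)
grad = FiniteDifferences.grad(central_fdm(5, 1), f, params)[1]
```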

src/DynamicPPL.jl (+2 -1)

@@ -189,7 +189,6 @@ include("context_implementations.jl")
 include("compiler.jl")
 include("pointwise_logdensities.jl")
 include("submodel_macro.jl")
-include("test_utils.jl")
 include("transforming.jl")
 include("logdensityfunction.jl")
 include("model_utils.jl")
@@ -199,6 +198,8 @@ include("values_as_in_model.jl")
 include("debug_utils.jl")
 using .DebugUtils
 
+include("test_utils.jl")
+
 include("experimental.jl")
 include("deprecated.jl")
 

src/test_utils.jl (+1)

@@ -18,5 +18,6 @@ include("test_utils/models.jl")
 include("test_utils/contexts.jl")
 include("test_utils/varinfo.jl")
 include("test_utils/sampler.jl")
+include("test_utils/ad.jl")
 
 end

src/test_utils/ad.jl (new file, +201 lines)

module AD

import ADTypes: AbstractADType
import DifferentiationInterface as DI
import ..DynamicPPL: DynamicPPL, Model, LogDensityFunction, VarInfo, AbstractVarInfo
import LogDensityProblems: logdensity, logdensity_and_gradient
import LogDensityProblemsAD: ADgradient
import Random: Random, AbstractRNG
import Test: @test

export make_function, make_params, ad_ldp, ad_di, test_correctness

"""
    flipped_logdensity(θ, ldf)

Flips the order of arguments for `logdensity` to match the signature needed
for DifferentiationInterface.jl.
"""
flipped_logdensity(θ, ldf) = logdensity(ldf, θ)

"""
    ad_ldp(
        model::Model,
        params::Vector{<:Real},
        adtype::AbstractADType,
        varinfo::AbstractVarInfo=VarInfo(model)
    )

Calculate the logdensity of `model` and its gradient using the AD backend
`adtype`, evaluated at the parameters `params`, using the implementation of
`logdensity_and_gradient` in the LogDensityProblemsAD.jl package.

The `varinfo` argument is optional and is used to provide the container
structure for the parameters. Note that the _parameters_ inside the `varinfo`
argument itself are overridden by the `params` argument. This argument defaults
to [`DynamicPPL.VarInfo`](@ref), which is the default container structure used
throughout the Turing ecosystem; however, you can provide e.g.
[`DynamicPPL.SimpleVarInfo`](@ref) if you want to use a different container
structure.

Returns a tuple `(value, gradient)` where `value <: Real` is the logdensity
of the model evaluated at `params`, and `gradient <: Vector{<:Real}` is the
gradient of the logdensity with respect to `params`.

Note that DynamicPPL.jl and Turing.jl currently use LogDensityProblemsAD.jl
throughout, and hence this function most closely mimics the usage of AD within
the Turing ecosystem.

For some AD backends such as Mooncake.jl, LogDensityProblemsAD.jl simply defers
to the DifferentiationInterface.jl package. In such a case, `ad_ldp` simplifies
to `ad_di` (in that if `ad_di` passes, one should expect `ad_ldp` to pass as
well).

However, there are other AD backends which still have custom code in
LogDensityProblemsAD.jl (such as ForwardDiff.jl). For these backends, `ad_di`
may yield different results compared to `ad_ldp`, and the behaviour of `ad_di`
is in such cases not guaranteed to be consistent with the behaviour of
Turing.jl.

See also: [`ad_di`](@ref).
"""
function ad_ldp(
    model::Model,
    params::Vector{<:Real},
    adtype::AbstractADType,
    vi::AbstractVarInfo=VarInfo(model),
)
    ldf = LogDensityFunction(model, vi)
    # Note that the implementation of logdensity takes care of setting the
    # parameters in vi to the correct values (using unflatten)
    return logdensity_and_gradient(ADgradient(adtype, ldf), params)
end

"""
    ad_di(
        model::Model,
        params::Vector{<:Real},
        adtype::AbstractADType,
        varinfo::AbstractVarInfo=VarInfo(model)
    )

Calculate the logdensity of `model` and its gradient using the AD backend
`adtype`, evaluated at the parameters `params`, directly using
DifferentiationInterface.jl.

See the notes in [`ad_ldp`](@ref) for more details on the differences between
`ad_di` and `ad_ldp`.
"""
function ad_di(
    model::Model,
    params::Vector{<:Real},
    adtype::AbstractADType,
    vi::AbstractVarInfo=VarInfo(model),
)
    ldf = LogDensityFunction(model, vi)
    # Note that the implementation of logdensity takes care of setting the
    # parameters in vi to the correct values (using unflatten)
    prep = DI.prepare_gradient(flipped_logdensity, adtype, params, DI.Constant(ldf))
    return DI.value_and_gradient(flipped_logdensity, prep, adtype, params, DI.Constant(ldf))
end

"""
    make_function(model, varinfo::AbstractVarInfo=VarInfo(model))

Generate the function to be differentiated. Specifically,
`make_function(model)` returns a function which takes a single argument
`params` and returns the logdensity of `model` evaluated at `params`.

The `varinfo` parameter is optional and is used to determine the structure of
the varinfo used during evaluation. See the [`ad_ldp`](@ref) function for more
details on the `varinfo` argument.

If you have an AD package that does not have integrations with either
LogDensityProblemsAD.jl (in which case you can use [`ad_ldp`](@ref)) or
DifferentiationInterface.jl (in which case you can use [`ad_di`](@ref)), you
can test whether your AD package works with Turing.jl models using:

```julia
f = make_function(model)
params = make_params(model)
value, grad = YourADPackage.gradient(f, params)
```

and compare the results against those obtained from either `ad_ldp` or `ad_di`
for an existing AD package that _is_ supported.

See also: [`make_params`](@ref).
"""
function make_function(model::Model, vi::AbstractVarInfo=VarInfo(model))
    # TODO: Can we simplify this even further by inlining the definition of
    # logdensity?
    return Base.Fix1(logdensity, LogDensityFunction(model, vi))
end

"""
    make_params(model, rng::Random.AbstractRNG=Random.default_rng())

Generate a vector of parameters sampled from the prior distribution of `model`.
This can be used as the input to the function to be differentiated. See
[`make_function`](@ref) for more details.
"""
function make_params(model::Model, rng::AbstractRNG=Random.default_rng())
    return VarInfo(rng, model)[:]
end

"""
    test_correctness(
        ad_function,
        model::Model,
        adtypes::Vector{<:ADTypes.AbstractADType},
        reference_adtype::ADTypes.AbstractADType,
        rng::Random.AbstractRNG=Random.default_rng(),
        params::Vector{<:Real}=VarInfo(rng, model)[:];
        value_atol=1e-6,
        grad_atol=1e-6
    )

Test the correctness of all the AD backends `adtypes` for the model `model`
using the implementation `ad_function`. `ad_function` should be either
[`ad_ldp`](@ref) or [`ad_di`](@ref), or a custom function that has the same
signature.

The test is performed by calculating the logdensity and its gradient using all
the AD backends, and comparing the results against those obtained with the
reference AD backend `reference_adtype`.

The parameters can either be passed explicitly using the `params` argument, or
can be sampled from the prior distribution of the model using the `rng`
argument.
"""
function test_correctness(
    ad_function,
    model::Model,
    adtypes::Vector{<:AbstractADType},
    reference_adtype::AbstractADType,
    rng::AbstractRNG=Random.default_rng(),
    params::Vector{<:Real}=VarInfo(rng, model)[:];
    value_atol=1e-6,
    grad_atol=1e-6,
)
    value_true, grad_true = ad_function(model, params, reference_adtype)
    for adtype in adtypes
        value, grad = ad_function(model, params, adtype)
        info_str = join(
            [
                "Testing AD correctness",
                "  AD function : $(ad_function)",
                "  backend     : $(adtype)",
                "  model       : $(model.f)",
                "  params      : $(params)",
                "  actual      : $((value, grad))",
                "  expected    : $((value_true, grad_true))",
            ],
            "\n",
        )
        @info info_str
        @test value ≈ value_true atol = value_atol
        @test grad ≈ grad_true atol = grad_atol
    end
end

end # module DynamicPPL.TestUtils.AD
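The docstrings above describe two evaluation routes: LogDensityProblemsAD.jl via `ad_ldp` and DifferentiationInterface.jl via `ad_di`, which for some backends take different code paths. One natural sanity check is to run both routes on the same model and backend and compare the results; a mismatch would point at one of the integrations rather than at the AD backend itself. The sketch below is illustrative only: the model is an assumption and ForwardDiff is just one possible backend choice.

```julia
# Hypothetical consistency check between the two routes (not part of this commit).
using DynamicPPL, Distributions, ADTypes, ForwardDiff

@model demo_routes() = x ~ Normal()
model = demo_routes()
params = DynamicPPL.TestUtils.AD.make_params(model)

v_ldp, g_ldp = DynamicPPL.TestUtils.AD.ad_ldp(model, params, AutoForwardDiff())
v_di, g_di = DynamicPPL.TestUtils.AD.ad_di(model, params, AutoForwardDiff())

# Compare value and gradient from the two integration paths.
v_ldp ≈ v_di && g_ldp ≈ g_di
```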

test/ad.jl (+10 -10)

@@ -6,29 +6,29 @@
     varinfos = DynamicPPL.TestUtils.setup_varinfos(m, rand_param_values, vns)
 
     @testset "$(short_varinfo_name(varinfo))" for varinfo in varinfos
-        f = DynamicPPL.LogDensityFunction(m, varinfo)
-
-        # use ForwardDiff result as reference
-        ad_forwarddiff_f = LogDensityProblemsAD.ADgradient(
-            ADTypes.AutoForwardDiff(; chunksize=0), f
-        )
         # convert to `Vector{Float64}` to avoid `ReverseDiff` initializing the gradients to Integer 0
         # reference: https://github.com/TuringLang/DynamicPPL.jl/pull/571#issuecomment-1924304489
-        θ = convert(Vector{Float64}, varinfo[:])
-        logp, ref_grad = LogDensityProblems.logdensity_and_gradient(ad_forwarddiff_f, θ)
+        params = convert(Vector{Float64}, varinfo[:])
+        # Use ForwardDiff as reference AD backend
+        ref_logp, ref_grad = DynamicPPL.TestUtils.AD.ad_ldp(
+            m, params, ADTypes.AutoForwardDiff()
+        )
 
+        # Test correctness of all other backends
         @testset "$adtype" for adtype in [
             ADTypes.AutoReverseDiff(; compile=false),
             ADTypes.AutoReverseDiff(; compile=true),
             ADTypes.AutoMooncake(; config=nothing),
         ]
+            @info "Testing AD correctness: $(m.f), $(adtype), $(short_varinfo_name(varinfo))"
+
             # Mooncake can't currently handle something that is going on in
             # SimpleVarInfo{<:VarNamedVector}. Disable all SimpleVarInfo tests for now.
             if adtype isa ADTypes.AutoMooncake && varinfo isa DynamicPPL.SimpleVarInfo
                 @test_broken 1 == 0
             else
-                ad_f = LogDensityProblemsAD.ADgradient(adtype, f)
-                _, grad = LogDensityProblems.logdensity_and_gradient(ad_f, θ)
+                logp, grad = DynamicPPL.TestUtils.AD.ad_ldp(m, params, adtype)
+                @test logp ≈ ref_logp
                 @test grad ≈ ref_grad
             end
         end
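The diff above computes a ForwardDiff reference and then loops over the remaining backends by hand. The new module also exports `test_correctness`, which wraps the same compare-against-a-reference pattern; the sketch below shows how such a call might look. The model and backend list are assumptions chosen for illustration, not what this test suite actually runs.

```julia
# Hypothetical sketch of driving test_correctness directly (assumed model and backends).
using DynamicPPL, Distributions, ADTypes, ForwardDiff, ReverseDiff

@model function gdemo(x)
    s ~ InverseGamma(2, 3)
    m ~ Normal(0, sqrt(s))
    x ~ Normal(m, sqrt(s))
end

DynamicPPL.TestUtils.AD.test_correctness(
    DynamicPPL.TestUtils.AD.ad_ldp,       # route to test: LogDensityProblemsAD
    gdemo(1.5),                           # model with an assumed observation
    [AutoReverseDiff(; compile=false), AutoReverseDiff(; compile=true)],
    AutoForwardDiff(),                    # reference backend
)
```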
