Skip to content

Implement AD testing and benchmarking (with DITest) #883

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 5 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[weakdeps]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
DifferentiationInterfaceTest = "a82114a7-5aa3-49a8-9643-716bb13727a3"
EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b"
Expand All @@ -35,6 +36,7 @@ Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6"

[extensions]
DynamicPPLChainRulesCoreExt = ["ChainRulesCore"]
DynamicPPLDifferentiationInterfaceTestExt = ["DifferentiationInterfaceTest"]
DynamicPPLEnzymeCoreExt = ["EnzymeCore"]
DynamicPPLForwardDiffExt = ["ForwardDiff"]
DynamicPPLJETExt = ["JET"]
Expand All @@ -52,6 +54,7 @@ ChainRulesCore = "1"
Compat = "4"
ConstructionBase = "1.5.4"
DifferentiationInterface = "0.6.41"
DifferentiationInterfaceTest = "0.9.6"
Distributions = "0.25"
DocStringExtensions = "0.9"
EnzymeCore = "0.6 - 0.8"
Expand Down
87 changes: 87 additions & 0 deletions ext/DynamicPPLDifferentiationInterfaceTestExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
module DynamicPPLDifferentiationInterfaceTestExt

using DynamicPPL:
DynamicPPL,
ADTypes,
LogDensityProblems,
Model,
DI, # DifferentiationInterface
AbstractVarInfo,
VarInfo,
LogDensityFunction
import DifferentiationInterfaceTest as DIT

"""
REFERENCE_ADTYPE

Reference AD backend to use for comparison. In this case, ForwardDiff.jl, since
it's the default AD backend used in Turing.jl.
"""
const REFERENCE_ADTYPE = ADTypes.AutoForwardDiff()

"""
make_scenario(
model::Model,
adtype::ADTypes.AbstractADType,
varinfo::AbstractVarInfo=VarInfo(model),
params::Vector{<:Real}=varinfo[:],
reference_adtype::ADTypes.AbstractADType=REFERENCE_ADTYPE,
expected_value_and_grad::Union{Nothing,Tuple{Real,Vector{<:Real}}}=nothing,
)

Construct a DifferentiationInterfaceTest.Scenario for the given `model` and `adtype`.

More docs to follow.
"""
function make_scenario(
model::Model,
adtype::ADTypes.AbstractADType;
varinfo::AbstractVarInfo=VarInfo(model),
params::Vector{<:Real}=varinfo[:],
reference_adtype::ADTypes.AbstractADType=REFERENCE_ADTYPE,
expected_grad::Union{Nothing,Vector{<:Real}}=nothing,
)
params = map(identity, params)
context = DynamicPPL.DefaultContext()
adtype = DynamicPPL.tweak_adtype(adtype, model, varinfo, context)
# Below is a performance optimisation, see: https://github.com/TuringLang/DynamicPPL.jl/pull/806#issuecomment-2658049143
if DynamicPPL.use_closure(adtype)
f = x -> DynamicPPL.logdensity_at(x, model, varinfo, context)
di_contexts = ()
else
f = DynamicPPL.logdensity_at
di_contexts = (DI.Constant(model), DI.Constant(varinfo), DI.Constant(context))
end

# Calculate ground truth to compare against
grad_true = if expected_grad === nothing
ldf_reference = LogDensityFunction(model, varinfo; adtype=reference_adtype)
LogDensityProblems.logdensity_and_gradient(ldf_reference, params)[2]
else
expected_grad
end

return DIT.Scenario{:gradient,:out}(
f, params; contexts=di_contexts, res1=grad_true, name="$(model.f)"
)
end

function DynamicPPL.TestUtils.AD.run_ad(
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

One alternative is to overload DI.test_differentiation so we can tweak adtype internally.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

DIT.Scenario doesn't contain a type parameter that we can dispatch on, though.

(While putting the ADTests bit together, I also found out that Scenarios can't be prepared with a specific value of x: JuliaDiff/DifferentiationInterface.jl#771)

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I guess we could do something like test_differentiation(..., ::DynamicPPL.Model, ::AbstractADType, ...) and have that construct the scenario

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I guess we could do something like test_differentiation(..., ::DynamicPPL.Model, ::AbstractADType, ...) and have that construct the scenario

I like this idea. One could have both

  • test_differentiation(..., ::DynamicPPL.Model, ::AbstractADType, ...)
  • benchmark_differentiation(..., ::DynamicPPL.Model, ::AbstractADType, ...)

Having unified interfaces across DI and Turing for these autodiff tests would be nice.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please keep in mind that DIT was designed mostly to test DI itself, so its interface is still rather dirty and unstable. Also, DIT.test_differentiation does way more than you probably need here. But if there is interest in standardization, we can take a look

Copy link
Member

@yebai yebai Apr 14, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks, @gdalle, for the context.

It is an excellent idea to improve DIT so it can become a community resource like DI. It's very helpful to have a standard interface where

  • packages (like DynamicPPL, Bijectors, Distributions) can register test scenerios for AD backends
  • AD backends can run registered test scenarios easily

It would help the autodiff dev community discover bugs more quickly. It would also inform the general users which AD backend is likely compatible with the library (e.g. Lux, Turing) they want to use (see, e.g. https://turinglang.org/ADTests/)

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

DIT is in the weird position where it simultaneously does much more than what we need and also doesn't do some of the things we need. I've said this elsewhere (in meetings etc) but this isn't a criticism of DIT, it's just about choosing the right tool for the job IMO.

model::Model,
adtype::ADTypes.AbstractADType;
varinfo::AbstractVarInfo=VarInfo(model),
params::Vector{<:Real}=varinfo[:],
reference_adtype::ADTypes.AbstractADType=REFERENCE_ADTYPE,
expected_grad::Union{Nothing,Vector{<:Real}}=nothing,
kwargs...,
)
scen = make_scenario(model, adtype; varinfo=varinfo, expected_grad=expected_grad)
tweaked_adtype = DynamicPPL.tweak_adtype(
adtype, model, varinfo, DynamicPPL.DefaultContext()
)
return DIT.test_differentiation(
tweaked_adtype, [scen]; scenario_intact=false, kwargs...
)
end

end
4 changes: 4 additions & 0 deletions src/test_utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -19,4 +19,8 @@ include("test_utils/contexts.jl")
include("test_utils/varinfo.jl")
include("test_utils/sampler.jl")

module AD
function run_ad end
end

end
1 change: 1 addition & 0 deletions test/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ Bijectors = "76274a88-744f-5084-9051-94815aaf08c4"
Combinatorics = "861a8166-3701-5b0c-9a16-15d98fcdc6aa"
Compat = "34da2185-b29b-5c13-b0c7-acf172513d20"
DifferentiationInterface = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63"
DifferentiationInterfaceTest = "a82114a7-5aa3-49a8-9643-716bb13727a3"
Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
DistributionsAD = "ced4e74d-a319-5a8a-b0ac-84af2272839c"
Expand Down
10 changes: 5 additions & 5 deletions test/ad.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
using DynamicPPL: LogDensityFunction
import DifferentiationInterfaceTest as DIT

@testset "Automatic differentiation" begin
# Used as the ground truth that others are compared against.
Expand Down Expand Up @@ -27,7 +28,7 @@ using DynamicPPL: LogDensityFunction
x = DynamicPPL.getparams(f)
# Calculate reference logp + gradient of logp using ForwardDiff
ref_ldf = LogDensityFunction(m, varinfo; adtype=ref_adtype)
ref_logp, ref_grad = LogDensityProblems.logdensity_and_gradient(ref_ldf, x)
ref_grad = LogDensityProblems.logdensity_and_gradient(ref_ldf, x)[2]

@testset "$adtype" for adtype in test_adtypes
@info "Testing AD on: $(m.f) - $(short_varinfo_name(varinfo)) - $adtype"
Expand Down Expand Up @@ -56,10 +57,9 @@ using DynamicPPL: LogDensityFunction
ref_ldf, adtype
)
else
ldf = DynamicPPL.LogDensityFunction(ref_ldf, adtype)
logp, grad = LogDensityProblems.logdensity_and_gradient(ldf, x)
@test grad ≈ ref_grad
@test logp ≈ ref_logp
DynamicPPL.TestUtils.AD.run_ad(
m, adtype; varinfo=varinfo, expected_grad=ref_grad
)
end
end
end
Expand Down