diff --git a/HISTORY.md b/HISTORY.md
index a21258ec0..a45644a64 100644
--- a/HISTORY.md
+++ b/HISTORY.md
@@ -4,6 +4,51 @@
 
 **Breaking changes**
 
+### AD testing utilities
+
+`DynamicPPL.TestUtils.AD.run_ad` now links the VarInfo by default.
+To disable this, pass an unlinked VarInfo explicitly via the `varinfo` keyword argument (for example, `varinfo=VarInfo(model)`).
+If the calculated value or gradient is incorrect, `run_ad` now throws a `DynamicPPL.TestUtils.AD.ADIncorrectException` rather than registering a test failure.
+This exception contains the expected and actual value and gradient so that you can inspect them if needed; see the documentation for more information.
+In practice, this means that if you call `run_ad` inside a test suite, you should write `@test run_ad(...) isa Any` rather than a bare `run_ad(...)`.
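+
+For example, a minimal test-suite entry might look like the following sketch (the model, imports, and AD backend here are purely illustrative):
+
+```julia
+using ADTypes: AutoForwardDiff
+using Distributions: Normal
+using DynamicPPL, ForwardDiff, Test
+using DynamicPPL.TestUtils.AD: run_ad
+
+@model function demo()
+    x ~ Normal()
+end
+
+# `run_ad` throws an `ADIncorrectException` if the value or gradient is wrong,
+# so wrapping the call in `@test ... isa Any` is enough to surface failures.
+@test run_ad(demo(), AutoForwardDiff()) isa Any
+```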
""" -struct ADResult +struct ADResult{Tparams<:AbstractFloat,Tresult<:AbstractFloat} "The DynamicPPL model that was tested" model::Model "The VarInfo that was used" varinfo::AbstractVarInfo "The values at which the model was evaluated" - params::Vector{<:Real} + params::Vector{Tparams} "The AD backend that was tested" adtype::AbstractADType "The absolute tolerance for the value of logp" - value_atol::Real + value_atol::Tresult "The absolute tolerance for the gradient of logp" - grad_atol::Real + grad_atol::Tresult "The expected value of logp" - value_expected::Union{Nothing,Float64} + value_expected::Union{Nothing,Tresult} "The expected gradient of logp" - grad_expected::Union{Nothing,Vector{Float64}} + grad_expected::Union{Nothing,Vector{Tresult}} "The value of logp (calculated using `adtype`)" - value_actual::Union{Nothing,Real} + value_actual::Union{Nothing,Tresult} "The gradient of logp (calculated using `adtype`)" - grad_actual::Union{Nothing,Vector{Float64}} + grad_actual::Union{Nothing,Vector{Tresult}} "If benchmarking was requested, the time taken by the AD backend to calculate the gradient of logp, divided by the time taken to evaluate logp itself" - time_vs_primal::Union{Nothing,Float64} + time_vs_primal::Union{Nothing,Tresult} end """ @@ -64,26 +75,27 @@ end benchmark=false, value_atol=1e-6, grad_atol=1e-6, - varinfo::AbstractVarInfo=VarInfo(model), - params::Vector{<:Real}=varinfo[:], + varinfo::AbstractVarInfo=link(VarInfo(model), model), + params::Union{Nothing,Vector{<:AbstractFloat}}=nothing, reference_adtype::ADTypes.AbstractADType=REFERENCE_ADTYPE, - expected_value_and_grad::Union{Nothing,Tuple{Real,Vector{<:Real}}}=nothing, + expected_value_and_grad::Union{Nothing,Tuple{AbstractFloat,Vector{<:AbstractFloat}}}=nothing, verbose=true, )::ADResult +### Description + Test the correctness and/or benchmark the AD backend `adtype` for the model `model`. Whether to test and benchmark is controlled by the `test` and `benchmark` keyword arguments. By default, `test` is `true` and `benchmark` is `false`. -Returns an [`ADResult`](@ref) object, which contains the results of the -test and/or benchmark. - Note that to run AD successfully you will need to import the AD backend itself. For example, to test with `AutoReverseDiff()` you will need to run `import ReverseDiff`. +### Arguments + There are two positional arguments, which absolutely must be provided: 1. `model` - The model being tested. @@ -96,7 +108,9 @@ Everything else is optional, and can be categorised into several groups: DynamicPPL contains several different types of VarInfo objects which change the way model evaluation occurs. If you want to use a specific type of VarInfo, pass it as the `varinfo` argument. Otherwise, it will default to - using a `TypedVarInfo` generated from the model. + using a linked `TypedVarInfo` generated from the model. Here, _linked_ + means that the parameters in the VarInfo have been transformed to + unconstrained Euclidean space if they aren't already in that space. 2. _How to specify the parameters._ @@ -140,27 +154,40 @@ Everything else is optional, and can be categorised into several groups: By default, this function prints messages when it runs. To silence it, set `verbose=false`. + +### Returns / Throws + +Returns an [`ADResult`](@ref) object, which contains the results of the +test and/or benchmark. + +If `test` is `true` and the AD backend returns an incorrect value or gradient, an +`ADIncorrectException` is thrown. If a different error occurs, it will be +thrown as-is. 
""" function run_ad( model::Model, adtype::AbstractADType; - test=true, - benchmark=false, - value_atol=1e-6, - grad_atol=1e-6, - varinfo::AbstractVarInfo=VarInfo(model), - params::Vector{<:Real}=varinfo[:], + test::Bool=true, + benchmark::Bool=false, + value_atol::AbstractFloat=1e-6, + grad_atol::AbstractFloat=1e-6, + varinfo::AbstractVarInfo=link(VarInfo(model), model), + params::Union{Nothing,Vector{<:AbstractFloat}}=nothing, reference_adtype::AbstractADType=REFERENCE_ADTYPE, - expected_value_and_grad::Union{Nothing,Tuple{Real,Vector{<:Real}}}=nothing, + expected_value_and_grad::Union{Nothing,Tuple{AbstractFloat,Vector{<:AbstractFloat}}}=nothing, verbose=true, )::ADResult + if isnothing(params) + params = varinfo[:] + end + params = map(identity, params) # Concretise + verbose && @info "Running AD on $(model.f) with $(adtype)\n" - params = map(identity, params) verbose && println(" params : $(params)") ldf = LogDensityFunction(model, varinfo; adtype=adtype) value, grad = logdensity_and_gradient(ldf, params) - grad = _to_vec_f64(grad) + grad = collect(grad) verbose && println(" actual : $((value, grad))") if test @@ -172,10 +199,11 @@ function run_ad( expected_value_and_grad end verbose && println(" expected : $((value_true, grad_true))") - grad_true = _to_vec_f64(grad_true) - # Then compare - @test isapprox(value, value_true; atol=value_atol) - @test isapprox(grad, grad_true; atol=grad_atol) + grad_true = collect(grad_true) + + exc() = throw(ADIncorrectException(value, value_true, grad, grad_true)) + isapprox(value, value_true; atol=value_atol) || exc() + isapprox(grad, grad_true; atol=grad_atol) || exc() else value_true = nothing grad_true = nothing diff --git a/src/transforming.jl b/src/transforming.jl index 0239725ae..429562ec8 100644 --- a/src/transforming.jl +++ b/src/transforming.jl @@ -19,9 +19,9 @@ function tilde_assume( lp = Bijectors.logpdf_with_trans(right, r, !isinverse) if istrans(vi, vn) - @assert isinverse "Trying to link already transformed variables" + isinverse || @warn "Trying to link an already transformed variable ($vn)" else - @assert !isinverse "Trying to invlink non-transformed variables" + isinverse && @warn "Trying to invlink a non-transformed variable ($vn)" end # Only transform if `!isinverse` since `vi[vn, right]` diff --git a/test/ad.jl b/test/ad.jl index 33d581228..69ab99e19 100644 --- a/test/ad.jl +++ b/test/ad.jl @@ -23,21 +23,23 @@ using DynamicPPL: LogDensityFunction varinfos = DynamicPPL.TestUtils.setup_varinfos(m, rand_param_values, vns) @testset "$(short_varinfo_name(varinfo))" for varinfo in varinfos - f = LogDensityFunction(m, varinfo) + linked_varinfo = DynamicPPL.link(varinfo, m) + f = LogDensityFunction(m, linked_varinfo) x = DynamicPPL.getparams(f) # Calculate reference logp + gradient of logp using ForwardDiff - ref_ldf = LogDensityFunction(m, varinfo; adtype=ref_adtype) + ref_ldf = LogDensityFunction(m, linked_varinfo; adtype=ref_adtype) ref_logp, ref_grad = LogDensityProblems.logdensity_and_gradient(ref_ldf, x) @testset "$adtype" for adtype in test_adtypes - @info "Testing AD on: $(m.f) - $(short_varinfo_name(varinfo)) - $adtype" + @info "Testing AD on: $(m.f) - $(short_varinfo_name(linked_varinfo)) - $adtype" # Put predicates here to avoid long lines is_mooncake = adtype isa AutoMooncake is_1_10 = v"1.10" <= VERSION < v"1.11" is_1_11 = v"1.11" <= VERSION < v"1.12" - is_svi_vnv = varinfo isa SimpleVarInfo{<:DynamicPPL.VarNamedVector} - is_svi_od = varinfo isa SimpleVarInfo{<:OrderedDict} + is_svi_vnv = + linked_varinfo isa 
 """
 function run_ad(
     model::Model,
     adtype::AbstractADType;
-    test=true,
-    benchmark=false,
-    value_atol=1e-6,
-    grad_atol=1e-6,
-    varinfo::AbstractVarInfo=VarInfo(model),
-    params::Vector{<:Real}=varinfo[:],
+    test::Bool=true,
+    benchmark::Bool=false,
+    value_atol::AbstractFloat=1e-6,
+    grad_atol::AbstractFloat=1e-6,
+    varinfo::AbstractVarInfo=link(VarInfo(model), model),
+    params::Union{Nothing,Vector{<:AbstractFloat}}=nothing,
     reference_adtype::AbstractADType=REFERENCE_ADTYPE,
-    expected_value_and_grad::Union{Nothing,Tuple{Real,Vector{<:Real}}}=nothing,
+    expected_value_and_grad::Union{Nothing,Tuple{AbstractFloat,Vector{<:AbstractFloat}}}=nothing,
     verbose=true,
 )::ADResult
+    if isnothing(params)
+        params = varinfo[:]
+    end
+    params = map(identity, params) # Concretise
+
     verbose && @info "Running AD on $(model.f) with $(adtype)\n"
-    params = map(identity, params)
     verbose && println("      params : $(params)")
     ldf = LogDensityFunction(model, varinfo; adtype=adtype)
 
     value, grad = logdensity_and_gradient(ldf, params)
-    grad = _to_vec_f64(grad)
+    grad = collect(grad)
     verbose && println("      actual : $((value, grad))")
 
     if test
@@ -172,10 +199,11 @@ function run_ad(
             expected_value_and_grad
         end
         verbose && println("    expected : $((value_true, grad_true))")
-        grad_true = _to_vec_f64(grad_true)
-        # Then compare
-        @test isapprox(value, value_true; atol=value_atol)
-        @test isapprox(grad, grad_true; atol=grad_atol)
+        grad_true = collect(grad_true)
+
+        exc() = throw(ADIncorrectException(value, value_true, grad, grad_true))
+        isapprox(value, value_true; atol=value_atol) || exc()
+        isapprox(grad, grad_true; atol=grad_atol) || exc()
     else
         value_true = nothing
         grad_true = nothing
diff --git a/src/transforming.jl b/src/transforming.jl
index 0239725ae..429562ec8 100644
--- a/src/transforming.jl
+++ b/src/transforming.jl
@@ -19,9 +19,9 @@ function tilde_assume(
     lp = Bijectors.logpdf_with_trans(right, r, !isinverse)
 
     if istrans(vi, vn)
-        @assert isinverse "Trying to link already transformed variables"
+        isinverse || @warn "Trying to link an already transformed variable ($vn)"
     else
-        @assert !isinverse "Trying to invlink non-transformed variables"
+        isinverse && @warn "Trying to invlink a non-transformed variable ($vn)"
     end
 
     # Only transform if `!isinverse` since `vi[vn, right]`
diff --git a/test/ad.jl b/test/ad.jl
index 33d581228..69ab99e19 100644
--- a/test/ad.jl
+++ b/test/ad.jl
@@ -23,21 +23,23 @@ using DynamicPPL: LogDensityFunction
             varinfos = DynamicPPL.TestUtils.setup_varinfos(m, rand_param_values, vns)
 
             @testset "$(short_varinfo_name(varinfo))" for varinfo in varinfos
-                f = LogDensityFunction(m, varinfo)
+                linked_varinfo = DynamicPPL.link(varinfo, m)
+                f = LogDensityFunction(m, linked_varinfo)
                 x = DynamicPPL.getparams(f)
                 # Calculate reference logp + gradient of logp using ForwardDiff
-                ref_ldf = LogDensityFunction(m, varinfo; adtype=ref_adtype)
+                ref_ldf = LogDensityFunction(m, linked_varinfo; adtype=ref_adtype)
                 ref_logp, ref_grad = LogDensityProblems.logdensity_and_gradient(ref_ldf, x)
 
                 @testset "$adtype" for adtype in test_adtypes
-                    @info "Testing AD on: $(m.f) - $(short_varinfo_name(varinfo)) - $adtype"
+                    @info "Testing AD on: $(m.f) - $(short_varinfo_name(linked_varinfo)) - $adtype"
 
                     # Put predicates here to avoid long lines
                     is_mooncake = adtype isa AutoMooncake
                     is_1_10 = v"1.10" <= VERSION < v"1.11"
                     is_1_11 = v"1.11" <= VERSION < v"1.12"
-                    is_svi_vnv = varinfo isa SimpleVarInfo{<:DynamicPPL.VarNamedVector}
-                    is_svi_od = varinfo isa SimpleVarInfo{<:OrderedDict}
+                    is_svi_vnv =
+                        linked_varinfo isa SimpleVarInfo{<:DynamicPPL.VarNamedVector}
+                    is_svi_od = linked_varinfo isa SimpleVarInfo{<:OrderedDict}
 
                     # Mooncake doesn't work with several combinations of SimpleVarInfo.
                     if is_mooncake && is_1_11 && is_svi_vnv
@@ -56,12 +58,12 @@ using DynamicPPL: LogDensityFunction
                             ref_ldf, adtype
                         )
                     else
-                        DynamicPPL.TestUtils.AD.run_ad(
+                        @test DynamicPPL.TestUtils.AD.run_ad(
                             m,
                             adtype;
-                            varinfo=varinfo,
+                            varinfo=linked_varinfo,
                             expected_value_and_grad=(ref_logp, ref_grad),
-                        )
+                        ) isa Any
                     end
                 end
             end
diff --git a/test/simple_varinfo.jl b/test/simple_varinfo.jl
index aa3b592f7..380c24e7d 100644
--- a/test/simple_varinfo.jl
+++ b/test/simple_varinfo.jl
@@ -111,12 +111,6 @@
             # Should be approx. the same as the "lazy" transformation.
             @test logjoint(model, vi_linked) ≈ lp_linked
 
-            # TODO: Should not `VarInfo` also error here? The current implementation
-            # only warns and acts as a no-op.
-            if vi isa SimpleVarInfo
-                @test_throws AssertionError link!!(vi_linked, model)
-            end
-
             # `invlink!!`
             vi_invlinked = invlink!!(deepcopy(vi_linked), model)
             lp_invlinked = getlogp(vi_invlinked)