|
| 1 | +module AD |
| 2 | + |
| 3 | +import ADTypes: AbstractADType |
| 4 | +import DifferentiationInterface as DI |
| 5 | +import ..DynamicPPL: DynamicPPL, Model, LogDensityFunction, VarInfo, AbstractVarInfo |
| 6 | +import LogDensityProblems: logdensity, logdensity_and_gradient |
| 7 | +import LogDensityProblemsAD: ADgradient |
| 8 | +import Random: Random, AbstractRNG |
| 9 | +import Test: @test |
| 10 | + |
| 11 | +export make_function, make_params, ad_ldp, ad_di, test_correctness |
| 12 | + |
"""
    flipped_logdensity(θ, ldf)

Evaluate `logdensity(ldf, θ)`, i.e. call `logdensity` with its two arguments
swapped.

This wrapper exists so that the parameters `θ` come first, which is the
argument order needed when handing the function to DifferentiationInterface.jl.
"""
function flipped_logdensity(θ, ldf)
    return logdensity(ldf, θ)
end
| 20 | + |
"""
    ad_ldp(
        model::Model,
        params::Vector{<:Real},
        adtype::AbstractADType,
        varinfo::AbstractVarInfo=VarInfo(model)
    )

Compute the log density of `model` at the parameters `params`, together with
its gradient, using the AD backend `adtype` through the
`logdensity_and_gradient` implementation of LogDensityProblemsAD.jl.

# Arguments
- `model`: the model whose log density is evaluated.
- `params`: the point at which the log density and gradient are evaluated.
- `adtype`: the AD backend to use.
- `varinfo`: optional container structure for the parameters. Only its
  structure matters: the _parameters_ stored inside `varinfo` are overridden
  by `params`. It defaults to [`DynamicPPL.VarInfo`](@ref), the default
  container structure used throughout the Turing ecosystem, but you can pass
  e.g. [`DynamicPPL.SimpleVarInfo`](@ref) to use a different one.

# Returns
A tuple `(value, gradient)`, where `value` (a `Real`) is the log density of the
model at `params` and `gradient` (a `Vector{<:Real}`) is the gradient of the
log density with respect to `params`.

# Relationship to `ad_di`
DynamicPPL.jl and Turing.jl currently use LogDensityProblemsAD.jl throughout,
so this function most closely mimics how AD is used within the Turing
ecosystem.

For some AD backends, such as Mooncake.jl, LogDensityProblemsAD.jl simply
defers to DifferentiationInterface.jl. In that case `ad_ldp` reduces to `ad_di`
(in the sense that if `ad_di` passes, `ad_ldp` should be expected to pass too).

Other AD backends, such as ForwardDiff.jl, still have custom code in
LogDensityProblemsAD.jl. For those backends `ad_di` may give different results
from `ad_ldp`, and the behaviour of `ad_di` is then not guaranteed to be
consistent with the behaviour of Turing.jl.

See also: [`ad_di`](@ref).
"""
function ad_ldp(
    model::Model,
    params::Vector{<:Real},
    adtype::AbstractADType,
    vi::AbstractVarInfo=VarInfo(model),
)
    # `params` is not written into `vi` here: the implementation of
    # `logdensity` takes care of setting the parameters in `vi` to the correct
    # values (using `unflatten`).
    grad_ldf = ADgradient(adtype, LogDensityFunction(model, vi))
    return logdensity_and_gradient(grad_ldf, params)
end
| 73 | + |
"""
    ad_di(
        model::Model,
        params::Vector{<:Real},
        adtype::AbstractADType,
        varinfo::AbstractVarInfo=VarInfo(model)
    )

Compute the log density of `model` at the parameters `params`, together with
its gradient, using the AD backend `adtype` by calling
DifferentiationInterface.jl directly.

The gradient is taken with respect to `params` only; the log-density function
built from `model` and `varinfo` is passed to DifferentiationInterface.jl as a
constant context.

The notes in [`ad_ldp`](@ref) describe how `ad_di` and `ad_ldp` differ.
"""
function ad_di(
    model::Model,
    params::Vector{<:Real},
    adtype::AbstractADType,
    vi::AbstractVarInfo=VarInfo(model),
)
    # `params` is not written into `vi` here: the implementation of
    # `logdensity` takes care of setting the parameters in `vi` to the correct
    # values (using `unflatten`).
    context = DI.Constant(LogDensityFunction(model, vi))
    prep = DI.prepare_gradient(flipped_logdensity, adtype, params, context)
    return DI.value_and_gradient(flipped_logdensity, prep, adtype, params, context)
end
| 101 | + |
"""
    make_function(model, varinfo::AbstractVarInfo=VarInfo(model))

Build the function to be differentiated: `make_function(model)` returns a
function of a single argument `params` that evaluates the log density of
`model` at `params`.

The optional `varinfo` argument determines the structure of the varinfo used
during evaluation; see [`ad_ldp`](@ref) for more details on it.

This is intended for AD packages that have no integration with either
LogDensityProblemsAD.jl (for which you can use [`ad_ldp`](@ref)) or
DifferentiationInterface.jl (for which you can use [`ad_di`](@ref)). To check
whether such a package works with Turing.jl models, use:

```julia
f = make_function(model)
params = make_params(model)
value, grad = YourADPackage.gradient(f, params)
```

and compare the results against those obtained from either `ad_ldp` or `ad_di`
with an existing AD package that _is_ supported.

See also: [`make_params`](@ref).
"""
function make_function(model::Model, vi::AbstractVarInfo=VarInfo(model))
    # TODO: Can we simplify this even further by inlining the definition of
    # logdensity?
    ldf = LogDensityFunction(model, vi)
    # Fix the first argument of `logdensity`, leaving only `params` free.
    return Base.Fix1(logdensity, ldf)
end
| 134 | + |
"""
    make_params(model, rng::Random.AbstractRNG=Random.default_rng())

Return a vector of parameters sampled from the prior distribution of `model`,
drawn using `rng`.

The result can be used as the input to the function to be differentiated; see
[`make_function`](@ref) for more details.
"""
function make_params(model::Model, rng::AbstractRNG=Random.default_rng())
    sampled = VarInfo(rng, model)
    return sampled[:]
end
| 145 | + |
"""
    test_correctness(
        ad_function,
        model::Model,
        adtypes::Vector{<:AbstractADType},
        reference_adtype::AbstractADType,
        rng::AbstractRNG=Random.default_rng(),
        params::Vector{<:Real}=VarInfo(rng, model)[:];
        value_atol=1e-6,
        grad_atol=1e-6
    )

Check that every AD backend in `adtypes` agrees with the reference backend
`reference_adtype` on `model`, using the implementation `ad_function`.

The log density and its gradient are computed once with `reference_adtype` and
once per backend in `adtypes`. For each backend the results are logged with
`@info` and then compared against the reference with `@test`.

# Arguments
- `ad_function`: either [`ad_ldp`](@ref) or [`ad_di`](@ref), or a custom
  function that has the same signature. It is called as
  `ad_function(model, params, adtype)`.
- `model`: the model to test.
- `adtypes`: the AD backends under test.
- `reference_adtype`: the AD backend whose results are treated as correct.
- `rng`: random number generator used to sample `params` from the prior
  distribution of the model when `params` is not passed explicitly.
- `params`: the parameters at which to evaluate. May be passed explicitly
  instead of being sampled with `rng`.

# Keywords
- `value_atol`: absolute tolerance (`atol`) for the `≈` comparison of the log
  density values.
- `grad_atol`: absolute tolerance (`atol`) for the `≈` comparison of the
  gradients.

# Returns
`nothing`; the outcome is reported through `@test`.
"""
function test_correctness(
    ad_function,
    model::Model,
    adtypes::Vector{<:AbstractADType},
    reference_adtype::AbstractADType,
    rng::AbstractRNG=Random.default_rng(),
    params::Vector{<:Real}=VarInfo(rng, model)[:];
    value_atol=1e-6,
    grad_atol=1e-6,
)
    # The reference result is computed once and shared by all comparisons.
    value_true, grad_true = ad_function(model, params, reference_adtype)

    for adtype in adtypes
        value, grad = ad_function(model, params, adtype)

        # Log the full context first so a failing `@test` below can be traced
        # back to the backend, model and parameters that produced it.
        report = join(
            (
                "Testing AD correctness",
                " AD function : $(ad_function)",
                " backend : $(adtype)",
                " model : $(model.f)",
                " params : $(params)",
                " actual : $((value, grad))",
                " expected : $((value_true, grad_true))",
            ),
            "\n",
        )
        @info report

        @test value ≈ value_true atol = value_atol
        @test grad ≈ grad_true atol = grad_atol
    end

    return nothing
end
| 200 | + |
| 201 | +end # module DynamicPPL.TestUtils.AD |
0 commit comments