Skip to content

Allow specifying context in AD testing #935

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
May 23, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions HISTORY.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
# DynamicPPL Changelog

## 0.36.6

`DynamicPPL.TestUtils.run_ad` now takes an extra `context` keyword argument, which is passed to the `LogDensityFunction` constructor.

## 0.36.5

`varinfo[:]` now returns an empty vector if `varinfo::DynamicPPL.NTVarInfo` is empty, rather than erroring.
Expand Down
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "DynamicPPL"
uuid = "366bfd00-2699-11ea-058f-f148b4cae6d8"
version = "0.36.5"
version = "0.36.6"

[deps]
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
Expand Down
30 changes: 24 additions & 6 deletions src/test_utils/ad.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,14 @@
using Chairmarks: @be
import DifferentiationInterface as DI
using DocStringExtensions
using DynamicPPL: Model, LogDensityFunction, VarInfo, AbstractVarInfo, link
using DynamicPPL:
Model,
LogDensityFunction,
VarInfo,
AbstractVarInfo,
link,
DefaultContext,
AbstractContext
using LogDensityProblems: logdensity, logdensity_and_gradient
using Random: Random, Xoshiro
using Statistics: median
Expand Down Expand Up @@ -53,6 +60,8 @@
model::Model
"The VarInfo that was used"
varinfo::AbstractVarInfo
"The evaluation context that was used"
context::AbstractContext
"The values at which the model was evaluated"
params::Vector{Tparams}
"The AD backend that was tested"
Expand Down Expand Up @@ -83,6 +92,7 @@
grad_atol=1e-6,
varinfo::AbstractVarInfo=link(VarInfo(model), model),
params::Union{Nothing,Vector{<:AbstractFloat}}=nothing,
context::AbstractContext=DefaultContext(),
reference_adtype::ADTypes.AbstractADType=REFERENCE_ADTYPE,
expected_value_and_grad::Union{Nothing,Tuple{AbstractFloat,Vector{<:AbstractFloat}}}=nothing,
verbose=true,
Expand Down Expand Up @@ -136,7 +146,13 @@
prep_params)`. You could then evaluate the gradient at a different set of
parameters using the `params` keyword argument.

3. _How to specify the results to compare against._ (Only if `test=true`.)
3. _How to specify the evaluation context._

A `DynamicPPL.AbstractContext` can be passed as the `context` keyword
argument to control the evaluation context. This defaults to
`DefaultContext()`.

4. _How to specify the results to compare against._ (Only if `test=true`.)

Once logp and its gradient has been calculated with the specified `adtype`,
it must be tested for correctness.
Expand All @@ -151,12 +167,12 @@
The default reference backend is ForwardDiff. If none of these parameters are
specified, ForwardDiff will be used to calculate the ground truth.

4. _How to specify the tolerances._ (Only if `test=true`.)
5. _How to specify the tolerances._ (Only if `test=true`.)

The tolerances for the value and gradient can be set using `value_atol` and
`grad_atol`. These default to 1e-6.

5. _Whether to output extra logging information._
6. _Whether to output extra logging information._

By default, this function prints messages when it runs. To silence it, set
`verbose=false`.
Expand All @@ -179,6 +195,7 @@
grad_atol::AbstractFloat=1e-6,
varinfo::AbstractVarInfo=link(VarInfo(model), model),
params::Union{Nothing,Vector{<:AbstractFloat}}=nothing,
context::AbstractContext=DefaultContext(),
reference_adtype::AbstractADType=REFERENCE_ADTYPE,
expected_value_and_grad::Union{Nothing,Tuple{AbstractFloat,Vector{<:AbstractFloat}}}=nothing,
verbose=true,
Expand All @@ -190,7 +207,7 @@

verbose && @info "Running AD on $(model.f) with $(adtype)\n"
verbose && println(" params : $(params)")
ldf = LogDensityFunction(model, varinfo; adtype=adtype)
ldf = LogDensityFunction(model, varinfo, context; adtype=adtype)

value, grad = logdensity_and_gradient(ldf, params)
grad = collect(grad)
Expand All @@ -199,7 +216,7 @@
if test
# Calculate ground truth to compare against
value_true, grad_true = if expected_value_and_grad === nothing
ldf_reference = LogDensityFunction(model, varinfo; adtype=reference_adtype)
ldf_reference = LogDensityFunction(model, varinfo, context; adtype=reference_adtype)

Check warning on line 219 in src/test_utils/ad.jl

View check run for this annotation

Codecov / codecov/patch

src/test_utils/ad.jl#L219

Added line #L219 was not covered by tests
logdensity_and_gradient(ldf_reference, params)
else
expected_value_and_grad
Expand Down Expand Up @@ -228,6 +245,7 @@
return ADResult(
model,
varinfo,
context,
params,
adtype,
value_atol,
Expand Down
Loading