Implement AD testing and benchmarking (with DITest) #883


Closed
wants to merge 5 commits into from

Conversation

penelopeysm
Member

@penelopeysm penelopeysm commented Apr 4, 2025

Part 2 of two options; the other is #882.

Closes #869

Why am I not in favour of this one?

I think some exposition is required here, and I didn't have time to explain this super clearly during the meeting.

The API of DITest is like this:

  1. You construct a scenario, which includes the function f, the value x at which to evaluate the function (or its gradient), and a bunch of other things. Crucially, the scenario does not include the adtype.

  2. You then run the scenario with an adtype (or an array thereof).
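For concreteness, the two-step flow looks roughly like this (a sketch only; DIT's constructor signatures have changed between versions, so treat the exact names as approximate):

```julia
# Sketch of DIT's two-step API (signatures approximate; check the DIT docs).
using ADTypes: AutoForwardDiff
using DifferentiationInterfaceTest: Scenario, test_differentiation

f(x) = sum(abs2, x)          # function to differentiate
x = [1.0, 2.0, 3.0]          # point at which to evaluate it
grad = 2 .* x                # reference gradient to compare against

# Step 1: construct a scenario. Note that no adtype appears here.
scen = Scenario{:gradient,:out}(f, x; res1=grad)

# Step 2: run the scenario against one (or several) adtypes.
test_differentiation(AutoForwardDiff(), [scen])
```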

From the perspective of generic functions f, this is quite a nice interface. The tricky bit with DynamicPPL, as I briefly mentioned, is that when you pass LogDensityFunction a model, varinfo, etc. it does a bunch of things that not only change the function f being differentiated, but also potentially modify the adtype that is actually used. See, especially, this constructor:

```julia
function LogDensityFunction(
    model::Model,
    varinfo::AbstractVarInfo=VarInfo(model),
    context::AbstractContext=leafcontext(model.context);
    adtype::Union{ADTypes.AbstractADType,Nothing}=nothing,
)
```

(Note that LogDensityProblemsAD.jl used to do this stuff for us; #806 effectively removed it and inlined its optimisations into that inner constructor.)

What this means is that, to be completely consistent with the way DynamicPPL behaves, one has to:

  1. Reproduce the code inside src/logdensityfunctions.jl that generates the function f, so that the scenario can use the correct f.
  2. Because the above depends on the adtype, we have to make sure that scenarios generated with one adtype are later run with the same adtype.
    • In fact, the preparation in the LogDensityFunction doesn't only depend on the adtype; it potentially also modifies the adtype.
    • That's why this PR doesn't just include make_scenario; it also includes a run_ad function below, which ensures that the scenario is run with the appropriately modified adtype.
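To make that coupling concrete, here is a hypothetical sketch of what the two helpers have to do together; make_scenario, tweak_adtype, and make_logdensity_f are illustrative names for this sketch, not actual DynamicPPL API:

```julia
# Hypothetical sketch; none of these helper names are real DynamicPPL API.
function make_scenario(model, varinfo, adtype)
    # Mirror the prep inside src/logdensityfunctions.jl: it may *modify*
    # the adtype as well as determine which f is actually differentiated.
    adtype′ = tweak_adtype(adtype, model, varinfo)
    f = make_logdensity_f(model, varinfo, adtype′)
    return Scenario(f, varinfo[:]), adtype′
end

function run_ad(model, varinfo, adtype)
    scen, adtype′ = make_scenario(model, varinfo, adtype)
    # Crucially, run with adtype′, not the user-supplied adtype; otherwise
    # the results can differ from what actually happens during sampling.
    return test_differentiation(adtype′, [scen])
end
```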

If we adopt this PR, then we have to choose one of the following:

  1. Duplicating the code inside src/logdensityfunctions.jl, as I've done in this PR;
  2. Cutting this duplicated code out, which means that the results obtained when using this test/benchmark function will differ from the results when actually sampling a Turing model; or
  3. Removing the extra prep work inside src/logdensityfunctions.jl.

(3) is a no-go as it would have a noticeable impact on performance. And even though I think it'd be very nice if we could just export a list of scenarios, I'm not really comfortable with either (1) or (2), and I don't think exporting scenarios is a good enough reason to do either.

The alternative to this, #882, already makes the API very straightforward (it's just one function with a very thorough docstring), so I don't think it's unfair to define that as our interface, especially considering that we will most likely be the ones writing the integration tests for other people anyway.

@penelopeysm penelopeysm changed the title Implement AD testing (with DITest) Implement AD testing and benchmarking (with DITest) Apr 4, 2025
Contributor

github-actions bot commented Apr 4, 2025

Benchmark Report for Commit e1a34e1

Computer Information

Julia Version 1.11.4
Commit 8561cc3d68d (2025-03-10 11:36 UTC)
Build Info:
  Official https://julialang.org/ release
Platform Info:
  OS: Linux (x86_64-linux-gnu)
  CPU: 4 × AMD EPYC 7763 64-Core Processor
  WORD_SIZE: 64
  LLVM: libLLVM-16.0.6 (ORCJIT, znver3)
Threads: 1 default, 0 interactive, 1 GC (on 4 virtual cores)

Benchmark Results

|                 Model | Dimension |  AD Backend |      VarInfo Type | Linked | Eval Time / Ref Time | AD Time / Eval Time |
|-----------------------|-----------|-------------|-------------------|--------|----------------------|---------------------|
| Simple assume observe |         1 | forwarddiff |             typed |  false |                  9.9 |                 1.5 |
|           Smorgasbord |       201 | forwarddiff |             typed |  false |                617.9 |                42.6 |
|           Smorgasbord |       201 | forwarddiff | simple_namedtuple |   true |                419.8 |                48.4 |
|           Smorgasbord |       201 | forwarddiff |           untyped |   true |               1243.3 |                27.5 |
|           Smorgasbord |       201 | forwarddiff |       simple_dict |   true |               3937.0 |                20.4 |
|           Smorgasbord |       201 | reversediff |             typed |   true |               1459.6 |                29.8 |
|           Smorgasbord |       201 |    mooncake |             typed |   true |                944.3 |                 5.4 |
|    Loop univariate 1k |      1000 |    mooncake |             typed |   true |               5567.2 |                 4.1 |
|       Multivariate 1k |      1000 |    mooncake |             typed |   true |               1123.4 |                 8.2 |
|   Loop univariate 10k |     10000 |    mooncake |             typed |   true |              61969.3 |                 3.7 |
|      Multivariate 10k |     10000 |    mooncake |             typed |   true |               8946.0 |                 9.6 |
|               Dynamic |        10 |    mooncake |             typed |   true |                136.5 |                11.9 |
|              Submodel |         1 |    mooncake |             typed |   true |                 25.7 |                 7.7 |
|                   LDA |        12 | reversediff |             typed |   true |                479.8 |                 5.2 |


codecov bot commented Apr 4, 2025

Codecov Report

Attention: Patch coverage is 88.88889% with 2 lines in your changes missing coverage. Please review.

Project coverage is 84.89%. Comparing base (eed80e5) to head (e1a34e1).
Report is 1 commit behind head on main.

Files with missing lines Patch % Lines
ext/DynamicPPLDifferentiationInterfaceTestExt.jl 88.88% 2 Missing ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##             main     #883      +/-   ##
==========================================
+ Coverage   84.87%   84.89%   +0.01%     
==========================================
  Files          34       35       +1     
  Lines        3815     3833      +18     
==========================================
+ Hits         3238     3254      +16     
- Misses        577      579       +2     

@coveralls

coveralls commented Apr 4, 2025

Pull Request Test Coverage Report for Build 14256574630

Details

  • 0 of 14 (0.0%) changed or added relevant lines in 1 file are covered.
  • No unchanged relevant lines lost coverage.
  • Overall coverage decreased (-3.5%) to 81.418%

Changes Missing Coverage Covered Lines Changed/Added Lines %
ext/DynamicPPLDifferentiationInterfaceTestExt.jl 0 14 0.0%
Totals Coverage Status
Change from base Build 14127923718: -3.5%
Covered Lines: 3111
Relevant Lines: 3821


@coveralls

coveralls commented Apr 4, 2025

Pull Request Test Coverage Report for Build 14263072728

Details

  • 0 of 18 (0.0%) changed or added relevant lines in 1 file are covered.
  • 20 unchanged lines in 3 files lost coverage.
  • Overall coverage increased (+0.02%) to 84.983%

Changes Missing Coverage Covered Lines Changed/Added Lines %
ext/DynamicPPLDifferentiationInterfaceTestExt.jl 0 18 0.0%
Files with Coverage Reduction New Missed Lines %
src/model.jl 1 85.83%
src/varinfo.jl 3 84.51%
src/threadsafe.jl 16 55.05%
Totals Coverage Status
Change from base Build 14127923718: 0.02%
Covered Lines: 3254
Relevant Lines: 3829


@sunxd3
Member

sunxd3 commented Apr 8, 2025

The reasons for your preference are very valid. I also think that, since the hand-rolled version is not too complicated, it's worth maintaining it ourselves. Otherwise, new contributors would need to know what a DIT test scenario is before they could contribute to this.

@penelopeysm penelopeysm mentioned this pull request Apr 8, 2025
8 tasks
Member

@yebai yebai left a comment


The amount of (duplicated) code is quite minimal. It might make sense to merge this PR together with #882.

```julia
)
end

function DynamicPPL.TestUtils.AD.run_ad(
```
Member


One alternative is to overload DI.test_differentiation so we can tweak adtype internally.

Member Author


DIT.Scenario doesn't contain a type parameter that we can dispatch on, though.

(While putting the ADTests bit together, I also found out that Scenarios can't be prepared with a specific value of x: JuliaDiff/DifferentiationInterface.jl#771)

Member Author


I guess we could do something like test_differentiation(..., ::DynamicPPL.Model, ::AbstractADType, ...) and have that construct the scenario

Member


> I guess we could do something like test_differentiation(..., ::DynamicPPL.Model, ::AbstractADType, ...) and have that construct the scenario

I like this idea. One could have both

  • test_differentiation(..., ::DynamicPPL.Model, ::AbstractADType, ...)
  • benchmark_differentiation(..., ::DynamicPPL.Model, ::AbstractADType, ...)

Having unified interfaces across DI and Turing for these autodiff tests would be nice.
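Sketched out, the proposed overloads might look like this (hypothetical and not implemented anywhere; make_scenario is an illustrative helper, not existing API):

```julia
# Hypothetical dispatch-based overloads on DIT's entry points.
import DifferentiationInterfaceTest as DIT

function DIT.test_differentiation(
    model::DynamicPPL.Model, adtype::ADTypes.AbstractADType; kwargs...
)
    scen, adtype′ = make_scenario(model, adtype)  # illustrative helper
    return DIT.test_differentiation(adtype′, [scen]; kwargs...)
end

function DIT.benchmark_differentiation(
    model::DynamicPPL.Model, adtype::ADTypes.AbstractADType; kwargs...
)
    scen, adtype′ = make_scenario(model, adtype)
    return DIT.benchmark_differentiation(adtype′, [scen]; kwargs...)
end
```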

Contributor


Please keep in mind that DIT was designed mostly to test DI itself, so its interface is still rather dirty and unstable. Also, DIT.test_differentiation does way more than you probably need here. But if there is interest in standardization, we can take a look.

Member

@yebai yebai Apr 14, 2025


Thanks, @gdalle, for the context.

It is an excellent idea to improve DIT so it can become a community resource like DI. It's very helpful to have a standard interface where

  • packages (like DynamicPPL, Bijectors, Distributions) can register test scenarios for AD backends
  • AD backends can easily run the registered test scenarios

It would help the autodiff dev community discover bugs more quickly. It would also inform general users which AD backends are likely to be compatible with the library (e.g. Lux, Turing) they want to use (see, e.g., https://turinglang.org/ADTests/).

Member Author


DIT is in the weird position where it simultaneously does much more than what we need and also doesn't do some of the things we need. I've said this elsewhere (in meetings etc) but this isn't a criticism of DIT, it's just about choosing the right tool for the job IMO.

@penelopeysm
Member Author

Closing for now (with a note to keep an eye on DIT in case we can use it for future AD testing work).

@yebai yebai deleted the py/adtype2 branch April 17, 2025 09:52