@@ -3,6 +3,7 @@ Pkg.develop(; path=joinpath(@__DIR__, ".."))
3
3
4
4
import Test: @test , @testset
5
5
import ModelTests: MODELS, run_ad, ADIncorrectException
6
+ import DynamicPPL as D
6
7
using ADTypes
7
8
using Printf: @printf
8
9
@@ -21,10 +22,10 @@ NOTE: Make sure that the names are unique and do not contain commas
21
22
ADTYPES = Dict (
22
23
" ForwardDiff" => AutoForwardDiff (),
23
24
" ReverseDiff" => AutoReverseDiff (; compile= false ),
24
- " ReverseDiff:Compiled " => AutoReverseDiff (; compile= true ),
25
+ " ReverseDiffCompiled " => AutoReverseDiff (; compile= true ),
25
26
" Mooncake" => AutoMooncake (; config= nothing ),
26
- " Enzyme:Forward " => AutoEnzyme (; mode= set_runtime_activity (Forward, true )),
27
- " Enzyme:Reverse " => AutoEnzyme (; mode= set_runtime_activity (Reverse, true )),
27
+ " EnzymeForward " => AutoEnzyme (; mode= set_runtime_activity (Forward, true )),
28
+ " EnzymeReverse " => AutoEnzyme (; mode= set_runtime_activity (Reverse, true )),
28
29
)
29
30
30
31
if ARGS == [" --list-model-keys" ]
@@ -33,7 +34,17 @@ elseif ARGS == ["--list-adtype-keys"]
33
34
foreach (println, keys (ADTYPES))
34
35
elseif length (ARGS ) == 3 && ARGS [1 ] == " --run"
35
36
model, adtype = MODELS[ARGS [2 ]], ADTYPES[ARGS [3 ]]
36
- result = run_ad (model, adtype; benchmark= true )
37
+
38
+
39
+ if ARGS [2 ] == " control_flow"
40
+ # https://github.com/penelopeysm/ModelTests.jl/issues/4
41
+ vi = D. unflatten (D. VarInfo (model), [0.5 , - 0.5 ])
42
+ params = [- 0.5 , 0.5 ]
43
+ result = run_ad (model, adtype; varinfo= vi, params= params, benchmark= true )
44
+ else
45
+ result = run_ad (model, adtype; benchmark= true )
46
+ end
47
+
37
48
if isnothing (result. error)
38
49
@printf (" %.3f" , result. time_vs_primal)
39
50
elseif result. error isa ADIncorrectException
0 commit comments