Skip to content

Commit 1874554

Browse files
committed
Fix control_flow (?)
1 parent a5a0691 commit 1874554

File tree

2 files changed

+16
-4
lines changed

2 files changed

+16
-4
lines changed

app/Project.toml

+1
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
[deps]
22
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
3+
DynamicPPL = "366bfd00-2699-11ea-058f-f148b4cae6d8"
34
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
45
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
56
ModelTests = "d212b3af-bb80-4029-9a81-ee0a391ae514"

app/output.jl

+15-4
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ Pkg.develop(; path=joinpath(@__DIR__, ".."))
33

44
import Test: @test, @testset
55
import ModelTests: MODELS, run_ad, ADIncorrectException
6+
import DynamicPPL as D
67
using ADTypes
78
using Printf: @printf
89

@@ -21,10 +22,10 @@ NOTE: Make sure that the names are unique and do not contain commas
2122
ADTYPES = Dict(
2223
"ForwardDiff" => AutoForwardDiff(),
2324
"ReverseDiff" => AutoReverseDiff(; compile=false),
24-
"ReverseDiff:Compiled" => AutoReverseDiff(; compile=true),
25+
"ReverseDiffCompiled" => AutoReverseDiff(; compile=true),
2526
"Mooncake" => AutoMooncake(; config=nothing),
26-
"Enzyme:Forward" => AutoEnzyme(; mode=set_runtime_activity(Forward, true)),
27-
"Enzyme:Reverse" => AutoEnzyme(; mode=set_runtime_activity(Reverse, true)),
27+
"EnzymeForward" => AutoEnzyme(; mode=set_runtime_activity(Forward, true)),
28+
"EnzymeReverse" => AutoEnzyme(; mode=set_runtime_activity(Reverse, true)),
2829
)
2930

3031
if ARGS == ["--list-model-keys"]
@@ -33,7 +34,17 @@ elseif ARGS == ["--list-adtype-keys"]
3334
foreach(println, keys(ADTYPES))
3435
elseif length(ARGS) == 3 && ARGS[1] == "--run"
3536
model, adtype = MODELS[ARGS[2]], ADTYPES[ARGS[3]]
36-
result = run_ad(model, adtype; benchmark=true)
37+
38+
39+
if ARGS[2] == "control_flow"
40+
# https://github.com/penelopeysm/ModelTests.jl/issues/4
41+
vi = D.unflatten(D.VarInfo(model), [0.5, -0.5])
42+
params = [-0.5, 0.5]
43+
result = run_ad(model, adtype; varinfo=vi, params=params, benchmark=true)
44+
else
45+
result = run_ad(model, adtype; benchmark=true)
46+
end
47+
3748
if isnothing(result.error)
3849
@printf("%.3f", result.time_vs_primal)
3950
elseif result.error isa ADIncorrectException

0 commit comments

Comments
 (0)