Skip to content

Commit 0d0d87b

Browse files
committed
Test Enzyme
1 parent e4fa7f2 commit 0d0d87b

File tree

2 files changed

+168
-164
lines changed

2 files changed

+168
-164
lines changed

test/ad.jl

+119-115
Original file line numberDiff line numberDiff line change
@@ -1,26 +1,30 @@
11
using DynamicPPL: LogDensityFunction
2+
using EnzymeCore: set_runtime_activity, Forward, Reverse
23

34
@testset "Automatic differentiation" begin
45
# Used as the ground truth that others are compared against.
56
ref_adtype = AutoForwardDiff()
67
test_adtypes = [
7-
AutoReverseDiff(; compile=false),
8-
AutoReverseDiff(; compile=true),
9-
AutoMooncake(; config=nothing),
8+
# AutoReverseDiff(; compile=false),
9+
# AutoReverseDiff(; compile=true),
10+
# AutoMooncake(; config=nothing),
11+
AutoEnzyme(; mode=set_runtime_activity(Forward, true)),
12+
AutoEnzyme(; mode=set_runtime_activity(Reverse, true)),
1013
]
1114

12-
@testset "Unsupported backends" begin
13-
@model demo() = x ~ Normal()
14-
@test_logs (:warn, r"not officially supported") LogDensityFunction(
15-
demo(); adtype=AutoZygote()
16-
)
17-
end
15+
# @testset "Unsupported backends" begin
16+
# @model demo() = x ~ Normal()
17+
# @test_logs (:warn, r"not officially supported") LogDensityFunction(
18+
# demo(); adtype=AutoZygote()
19+
# )
20+
# end
1821

19-
@testset "Correctness: ForwardDiff, ReverseDiff, and Mooncake" begin
22+
@testset "Correctness on supported AD backends" begin
2023
@testset "$(m.f)" for m in DynamicPPL.TestUtils.DEMO_MODELS
21-
rand_param_values = DynamicPPL.TestUtils.rand_prior_true(m)
22-
vns = DynamicPPL.TestUtils.varnames(m)
23-
varinfos = DynamicPPL.TestUtils.setup_varinfos(m, rand_param_values, vns)
24+
# rand_param_values = DynamicPPL.TestUtils.rand_prior_true(m)
25+
# vns = DynamicPPL.TestUtils.varnames(m)
26+
# varinfos = DynamicPPL.TestUtils.setup_varinfos(m, rand_param_values, vns)
27+
varinfos = [VarInfo(m)]
2428

2529
@testset "$(short_varinfo_name(varinfo))" for varinfo in varinfos
2630
f = LogDensityFunction(m, varinfo)
@@ -66,106 +70,106 @@ using DynamicPPL: LogDensityFunction
6670
end
6771
end
6872

69-
@testset "Turing#2151: ReverseDiff compilation & eltype(vi, spl)" begin
70-
# Failing model
71-
t = 1:0.05:8
72-
σ = 0.3
73-
y = @. rand(sin(t) + Normal(0, σ))
74-
@model function state_space(y, TT, ::Type{T}=Float64) where {T}
75-
# Priors
76-
α ~ Normal(y[1], 0.001)
77-
τ ~ Exponential(1)
78-
η ~ filldist(Normal(0, 1), TT - 1)
79-
σ ~ Exponential(1)
80-
# create latent variable
81-
x = Vector{T}(undef, TT)
82-
x[1] = α
83-
for t in 2:TT
84-
x[t] = x[t - 1] + η[t - 1] * τ
85-
end
86-
# measurement model
87-
y ~ MvNormal(x, σ^2 * I)
88-
return x
89-
end
90-
model = state_space(y, length(t))
91-
92-
# Dummy sampling algorithm for testing. The test case can only be replicated
93-
# with a custom sampler, it doesn't work with SampleFromPrior(). We need to
94-
# overload assume so that model evaluation doesn't fail due to a lack
95-
# of implementation
96-
struct MyEmptyAlg end
97-
DynamicPPL.assume(
98-
::Random.AbstractRNG, ::DynamicPPL.Sampler{MyEmptyAlg}, dist, vn, vi
99-
) = DynamicPPL.assume(dist, vn, vi)
100-
101-
# Compiling the ReverseDiff tape used to fail here
102-
spl = Sampler(MyEmptyAlg())
103-
vi = VarInfo(model)
104-
ldf = LogDensityFunction(
105-
model, vi, SamplingContext(spl); adtype=AutoReverseDiff(; compile=true)
106-
)
107-
@test LogDensityProblems.logdensity_and_gradient(ldf, vi[:]) isa Any
108-
end
109-
110-
# Test that various different ways of specifying array types as arguments work with all
111-
# ADTypes.
112-
@testset "Array argument types" begin
113-
test_m = randn(2, 3)
114-
115-
function eval_logp_and_grad(model, m, adtype)
116-
ldf = LogDensityFunction(model(); adtype=adtype)
117-
return LogDensityProblems.logdensity_and_gradient(ldf, m[:])
118-
end
119-
120-
@model function scalar_matrix_model(::Type{T}=Float64) where {T<:Real}
121-
m = Matrix{T}(undef, 2, 3)
122-
return m ~ filldist(MvNormal(zeros(2), I), 3)
123-
end
124-
125-
scalar_matrix_model_reference = eval_logp_and_grad(
126-
scalar_matrix_model, test_m, ref_adtype
127-
)
128-
129-
@model function matrix_model(::Type{T}=Matrix{Float64}) where {T}
130-
m = T(undef, 2, 3)
131-
return m ~ filldist(MvNormal(zeros(2), I), 3)
132-
end
133-
134-
matrix_model_reference = eval_logp_and_grad(matrix_model, test_m, ref_adtype)
135-
136-
@model function scalar_array_model(::Type{T}=Float64) where {T<:Real}
137-
m = Array{T}(undef, 2, 3)
138-
return m ~ filldist(MvNormal(zeros(2), I), 3)
139-
end
140-
141-
scalar_array_model_reference = eval_logp_and_grad(
142-
scalar_array_model, test_m, ref_adtype
143-
)
144-
145-
@model function array_model(::Type{T}=Array{Float64}) where {T}
146-
m = T(undef, 2, 3)
147-
return m ~ filldist(MvNormal(zeros(2), I), 3)
148-
end
149-
150-
array_model_reference = eval_logp_and_grad(array_model, test_m, ref_adtype)
151-
152-
@testset "$adtype" for adtype in test_adtypes
153-
scalar_matrix_model_logp_and_grad = eval_logp_and_grad(
154-
scalar_matrix_model, test_m, adtype
155-
)
156-
@test scalar_matrix_model_logp_and_grad[1] scalar_matrix_model_reference[1]
157-
@test scalar_matrix_model_logp_and_grad[2] scalar_matrix_model_reference[2]
158-
matrix_model_logp_and_grad = eval_logp_and_grad(matrix_model, test_m, adtype)
159-
@test matrix_model_logp_and_grad[1] matrix_model_reference[1]
160-
@test matrix_model_logp_and_grad[2] matrix_model_reference[2]
161-
scalar_array_model_logp_and_grad = eval_logp_and_grad(
162-
scalar_array_model, test_m, adtype
163-
)
164-
@test scalar_array_model_logp_and_grad[1] scalar_array_model_reference[1]
165-
@test scalar_array_model_logp_and_grad[2] scalar_array_model_reference[2]
166-
array_model_logp_and_grad = eval_logp_and_grad(array_model, test_m, adtype)
167-
@test array_model_logp_and_grad[1] array_model_reference[1]
168-
@test array_model_logp_and_grad[2] array_model_reference[2]
169-
end
170-
end
73+
# @testset "Turing#2151: ReverseDiff compilation & eltype(vi, spl)" begin
74+
# # Failing model
75+
# t = 1:0.05:8
76+
# σ = 0.3
77+
# y = @. rand(sin(t) + Normal(0, σ))
78+
# @model function state_space(y, TT, ::Type{T}=Float64) where {T}
79+
# # Priors
80+
# α ~ Normal(y[1], 0.001)
81+
# τ ~ Exponential(1)
82+
# η ~ filldist(Normal(0, 1), TT - 1)
83+
# σ ~ Exponential(1)
84+
# # create latent variable
85+
# x = Vector{T}(undef, TT)
86+
# x[1] = α
87+
# for t in 2:TT
88+
# x[t] = x[t - 1] + η[t - 1] * τ
89+
# end
90+
# # measurement model
91+
# y ~ MvNormal(x, σ^2 * I)
92+
# return x
93+
# end
94+
# model = state_space(y, length(t))
95+
#
96+
# # Dummy sampling algorithm for testing. The test case can only be replicated
97+
# # with a custom sampler, it doesn't work with SampleFromPrior(). We need to
98+
# # overload assume so that model evaluation doesn't fail due to a lack
99+
# # of implementation
100+
# struct MyEmptyAlg end
101+
# DynamicPPL.assume(
102+
# ::Random.AbstractRNG, ::DynamicPPL.Sampler{MyEmptyAlg}, dist, vn, vi
103+
# ) = DynamicPPL.assume(dist, vn, vi)
104+
#
105+
# # Compiling the ReverseDiff tape used to fail here
106+
# spl = Sampler(MyEmptyAlg())
107+
# vi = VarInfo(model)
108+
# ldf = LogDensityFunction(
109+
# model, vi, SamplingContext(spl); adtype=AutoReverseDiff(; compile=true)
110+
# )
111+
# @test LogDensityProblems.logdensity_and_gradient(ldf, vi[:]) isa Any
112+
# end
113+
#
114+
# # Test that various different ways of specifying array types as arguments work with all
115+
# # ADTypes.
116+
# @testset "Array argument types" begin
117+
# test_m = randn(2, 3)
118+
#
119+
# function eval_logp_and_grad(model, m, adtype)
120+
# ldf = LogDensityFunction(model(); adtype=adtype)
121+
# return LogDensityProblems.logdensity_and_gradient(ldf, m[:])
122+
# end
123+
#
124+
# @model function scalar_matrix_model(::Type{T}=Float64) where {T<:Real}
125+
# m = Matrix{T}(undef, 2, 3)
126+
# return m ~ filldist(MvNormal(zeros(2), I), 3)
127+
# end
128+
#
129+
# scalar_matrix_model_reference = eval_logp_and_grad(
130+
# scalar_matrix_model, test_m, ref_adtype
131+
# )
132+
#
133+
# @model function matrix_model(::Type{T}=Matrix{Float64}) where {T}
134+
# m = T(undef, 2, 3)
135+
# return m ~ filldist(MvNormal(zeros(2), I), 3)
136+
# end
137+
#
138+
# matrix_model_reference = eval_logp_and_grad(matrix_model, test_m, ref_adtype)
139+
#
140+
# @model function scalar_array_model(::Type{T}=Float64) where {T<:Real}
141+
# m = Array{T}(undef, 2, 3)
142+
# return m ~ filldist(MvNormal(zeros(2), I), 3)
143+
# end
144+
#
145+
# scalar_array_model_reference = eval_logp_and_grad(
146+
# scalar_array_model, test_m, ref_adtype
147+
# )
148+
#
149+
# @model function array_model(::Type{T}=Array{Float64}) where {T}
150+
# m = T(undef, 2, 3)
151+
# return m ~ filldist(MvNormal(zeros(2), I), 3)
152+
# end
153+
#
154+
# array_model_reference = eval_logp_and_grad(array_model, test_m, ref_adtype)
155+
#
156+
# @testset "$adtype" for adtype in test_adtypes
157+
# scalar_matrix_model_logp_and_grad = eval_logp_and_grad(
158+
# scalar_matrix_model, test_m, adtype
159+
# )
160+
# @test scalar_matrix_model_logp_and_grad[1] ≈ scalar_matrix_model_reference[1]
161+
# @test scalar_matrix_model_logp_and_grad[2] ≈ scalar_matrix_model_reference[2]
162+
# matrix_model_logp_and_grad = eval_logp_and_grad(matrix_model, test_m, adtype)
163+
# @test matrix_model_logp_and_grad[1] ≈ matrix_model_reference[1]
164+
# @test matrix_model_logp_and_grad[2] ≈ matrix_model_reference[2]
165+
# scalar_array_model_logp_and_grad = eval_logp_and_grad(
166+
# scalar_array_model, test_m, adtype
167+
# )
168+
# @test scalar_array_model_logp_and_grad[1] ≈ scalar_array_model_reference[1]
169+
# @test scalar_array_model_logp_and_grad[2] ≈ scalar_array_model_reference[2]
170+
# array_model_logp_and_grad = eval_logp_and_grad(array_model, test_m, adtype)
171+
# @test array_model_logp_and_grad[1] ≈ array_model_reference[1]
172+
# @test array_model_logp_and_grad[2] ≈ array_model_reference[2]
173+
# end
174+
# end
171175
end

test/runtests.jl

+49-49
Original file line numberDiff line numberDiff line change
@@ -45,60 +45,60 @@ include("test_util.jl")
4545
# groups are chosen to make both groups take roughly the same amount of
4646
# time, but beyond that there is no particular reason for the split.
4747
if GROUP == "All" || GROUP == "Group1"
48-
if AQUA
49-
include("Aqua.jl")
50-
end
51-
include("utils.jl")
52-
include("compiler.jl")
53-
include("varnamedvector.jl")
54-
include("varinfo.jl")
55-
include("simple_varinfo.jl")
56-
include("model.jl")
57-
include("sampler.jl")
58-
include("independence.jl")
59-
include("distribution_wrappers.jl")
60-
include("logdensityfunction.jl")
61-
include("linking.jl")
62-
include("serialization.jl")
63-
include("pointwise_logdensities.jl")
64-
include("lkj.jl")
65-
include("contexts.jl")
66-
include("context_implementations.jl")
67-
include("threadsafe.jl")
68-
include("debug_utils.jl")
69-
include("deprecated.jl")
48+
# if AQUA
49+
# include("Aqua.jl")
50+
# end
51+
# include("utils.jl")
52+
# include("compiler.jl")
53+
# include("varnamedvector.jl")
54+
# include("varinfo.jl")
55+
# include("simple_varinfo.jl")
56+
# include("model.jl")
57+
# include("sampler.jl")
58+
# include("independence.jl")
59+
# include("distribution_wrappers.jl")
60+
# include("logdensityfunction.jl")
61+
# include("linking.jl")
62+
# include("serialization.jl")
63+
# include("pointwise_logdensities.jl")
64+
# include("lkj.jl")
65+
# include("contexts.jl")
66+
# include("context_implementations.jl")
67+
# include("threadsafe.jl")
68+
# include("debug_utils.jl")
69+
# include("deprecated.jl")
7070
end
7171

7272
if GROUP == "All" || GROUP == "Group2"
73-
@testset "compat" begin
74-
include(joinpath("compat", "ad.jl"))
75-
end
76-
@testset "extensions" begin
77-
include("ext/DynamicPPLMCMCChainsExt.jl")
78-
include("ext/DynamicPPLJETExt.jl")
79-
end
73+
# @testset "compat" begin
74+
# include(joinpath("compat", "ad.jl"))
75+
# end
76+
# @testset "extensions" begin
77+
# include("ext/DynamicPPLMCMCChainsExt.jl")
78+
# include("ext/DynamicPPLJETExt.jl")
79+
# end
8080
@testset "ad" begin
81-
include("ext/DynamicPPLForwardDiffExt.jl")
82-
include("ext/DynamicPPLMooncakeExt.jl")
81+
# include("ext/DynamicPPLForwardDiffExt.jl")
82+
# include("ext/DynamicPPLMooncakeExt.jl")
8383
include("ad.jl")
8484
end
85-
@testset "prob and logprob macro" begin
86-
@test_throws ErrorException prob"..."
87-
@test_throws ErrorException logprob"..."
88-
end
89-
@testset "doctests" begin
90-
DocMeta.setdocmeta!(
91-
DynamicPPL,
92-
:DocTestSetup,
93-
:(using DynamicPPL, Distributions);
94-
recursive=true,
95-
)
96-
doctestfilters = [
97-
# Ignore the source of a warning in the doctest output, since this is dependent on host.
98-
# This is a line that starts with "└ @ " and ends with the line number.
99-
r"└ @ .+:[0-9]+",
100-
]
101-
doctest(DynamicPPL; manual=false, doctestfilters=doctestfilters)
102-
end
85+
# @testset "prob and logprob macro" begin
86+
# @test_throws ErrorException prob"..."
87+
# @test_throws ErrorException logprob"..."
88+
# end
89+
# @testset "doctests" begin
90+
# DocMeta.setdocmeta!(
91+
# DynamicPPL,
92+
# :DocTestSetup,
93+
# :(using DynamicPPL, Distributions);
94+
# recursive=true,
95+
# )
96+
# doctestfilters = [
97+
# # Ignore the source of a warning in the doctest output, since this is dependent on host.
98+
# # This is a line that starts with "└ @ " and ends with the line number.
99+
# r"└ @ .+:[0-9]+",
100+
# ]
101+
# doctest(DynamicPPL; manual=false, doctestfilters=doctestfilters)
102+
# end
103103
end
104104
end

0 commit comments

Comments
 (0)