Skip to content

Commit 1250a47

Browse files
committed
fixes + tests
1 parent 5877d36 commit 1250a47

File tree

9 files changed

+830
-139
lines changed

9 files changed

+830
-139
lines changed

src/extra_rules.jl

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -246,6 +246,7 @@ end
246246
@ChainRules.non_differentiable Base.:(|)(a::Integer, b::Integer)
247247
@ChainRules.non_differentiable Base.throw(err)
248248
@ChainRules.non_differentiable Core.Compiler.return_type(args...)
249+
249250
ChainRulesCore.canonicalize(::NoTangent) = NoTangent()
250251

251252
# Disable thunking at higher order (TODO: These should go into ChainRulesCore)
@@ -259,3 +260,15 @@ function ChainRulesCore.rrule(::DiffractorRuleConfig, ::Type{InplaceableThunk},
259260
end
260261

261262
Base.real(z::NoTangent) = z # TODO should be in CRC, https://github.com/JuliaDiff/ChainRulesCore.jl/pull/581
263+
264+
# Avoid https://github.com/JuliaDiff/ChainRulesCore.jl/pull/495
265+
# Specialised `_backing_error` method for a `Base.Pairs` primal paired with a
# `NamedTuple` backing and an `AbstractDict` expectation: it does nothing and
# returns `nothing`.
# NOTE(review): presumably the generic ChainRulesCore method raises for this
# combination (see the PR linked in the comment above) — confirm upstream.
function ChainRulesCore._backing_error(
    P::Type{<:Base.Pairs}, G::Type{<:NamedTuple}, E::Type{<:AbstractDict}
)
    return nothing
end
266+
267+
# For gradient(pow_simd, 2, 3)[1] in zygote_features.jl
268+
# Declare the two-argument `Base.SimdLoop.simd_inner_length` as
# non-differentiable for ChainRules-based AD.
ChainRulesCore.@non_differentiable Base.SimdLoop.simd_inner_length(::Any, ::Any)
269+
270+
# This allows fill!(similar([1,2,3], ZeroTangent), false)
271+
# Convert a number to `ZeroTangent`.
#
# Only a numeric zero can be represented: any `x` for which `iszero(x)` is
# false raises `InexactError(:convert, ZeroTangent, x)`. Otherwise the result
# is `ZeroTangent()`.
function Base.convert(::Type{ZeroTangent}, x::Number)
    if !iszero(x)
        throw(InexactError(:convert, ZeroTangent, x))
    end
    return ZeroTangent()
end

src/runtime.jl

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,3 +28,9 @@ accum(x::Tangent{T}, y::Tangent) where T = _tangent(T, accum(backing(x), backing
2828
_tangent(::Type{T}, z) where T = Tangent{T,typeof(z)}(z)
2929
_tangent(::Type, ::NamedTuple{()}) = NoTangent()
3030
_tangent(::Type, ::NamedTuple{<:Any, <:Tuple{Vararg{AbstractZero}}}) = NoTangent()
31+
32+
# Accumulate a `Tangent` whose primal type `T` is a tuple with a plain `Tuple`.
#
# The two are combined by calling `accum` on `backing(x)` and `y`, and the sum
# is handed to `_tangent(T, ...)` to build the result for primal type `T`.
# (A diagnostic `@warn "gradient is both a Tangent and a Tuple" x y` was left
# disabled here; re-enable it when tracking down where the mixed
# representations come from.)
function accum(x::Tangent{T}, y::Tuple) where {T<:Tuple}
    summed = accum(backing(x), y)
    return _tangent(T, summed)
end
36+
# Mirror-image method: a plain `Tuple` on the left and a tuple `Tangent` on the
# right. Swaps the arguments and forwards to `accum(y, x)`.
function accum(x::Tuple, y::Tangent{<:Tuple})
    return accum(y, x)
end

test/chainrules.jl

Lines changed: 44 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,13 @@
11

22
# This file has integration tests for some rules defined in ChainRules.jl,
33
# especially those which aim to support higher derivatives, as properly
4-
# testing those is difficult.
4+
# testing those is difficult. Organised according to the files in CR.jl.
5+
6+
using Diffractor, ForwardDiff, ChainRulesCore
7+
using Test, LinearAlgebra
8+
9+
using Test: Threw, eval_test
510

6-
using Diffractor, ChainRulesCore, ForwardDiff
711

812
#####
913
##### Base/array.jl
@@ -13,7 +17,6 @@ using Diffractor, ChainRulesCore, ForwardDiff
1317

1418

1519

16-
1720
#####
1821
##### Base/arraymath.jl
1922
#####
@@ -33,21 +36,58 @@ using Diffractor, ChainRulesCore, ForwardDiff
3336
##### Base/indexing.jl
3437
#####
3538

39+
# Second-order gradients through `first` / `getindex`.
# The `sqrt∘first` case compares with `≈`: the expected first entry is
# -1/(4x₁²) = -0.25 at x₁ = 1.
@testset "getindex, first" begin
    @test_broken gradient(x -> sum(abs2, gradient(first, x)[1]), [1,2,3])[1] == [0, 0, 0]  # MethodError: no method matching +(::Tuple{ZeroTangent, ZeroTangent}, ::Tuple{ZeroTangent, ZeroTangent})
    @test_broken gradient(x -> sum(abs2, gradient(sqrt∘first, x)[1]), [1,2,3])[1] ≈ [-0.25, 0, 0]  # error() in perform_optic_transform(ff::Type{Diffractor.∂⃖recurse{2}}, args::Any)
    @test gradient(x -> sum(abs2, gradient(x -> x[1]^2, x)[1]), [1,2,3])[1] == [8, 0, 0]
    @test_broken gradient(x -> sum(abs2, gradient(x -> sum(x[1:2])^2, x)[1]), [1,2,3])[1] == [48, 0, 0]  # MethodError: no method matching +(::Tuple{ZeroTangent, ZeroTangent}, ::Tuple{ZeroTangent, ZeroTangent})
end
3645

37-
46+
# Gradients of reductions over the column iterator of a 2×3 matrix.
@testset "eachcol etc" begin
    mat = [1 2 3; 4 5 6]

    # Product of each column: every entry's gradient is the other entry in its column.
    @test gradient(m -> sum(prod, eachcol(m)), mat)[1] == [4 5 6; 1 2 3]
    # First entry of each column: only the top row contributes.
    @test gradient(m -> sum(first, eachcol(m)), mat)[1] == [1 1 1; 0 0 0]
    # Sum of the first column: only that column contributes.
    @test gradient(m -> sum(first(eachcol(m))), mat)[1] == [1 0 0; 1 0 0]

    # Second order is skipped.
    # MethodError: no method matching one(::Base.OneTo{Int64}), unzip_broadcast, split_bc_forwards
    @test_skip gradient(x -> sum(sin, gradient(m -> sum(first(eachcol(m))), x)[1]), mat)[1]
end
3852

3953
#####
4054
##### Base/mapreduce.jl
4155
#####
4256

57+
# Second-order gradients through `sum`, all evaluated at the same point.
@testset "sum" begin
    pt = [1, 2, 3]

    # The gradient of a plain `sum` is constant, so its own gradient is zero.
    @test gradient(x -> sum(abs2, gradient(sum, x)[1]), pt)[1] == [0,0,0]
    # The gradient of sum(abs2, x) is 2x; sum(abs2, 2x) then has gradient 8x.
    @test gradient(x -> sum(abs2, gradient(x -> sum(abs2, x), x)[1]), pt)[1] == [8,16,24]

    # The inner gradient is taken with respect to the broadcast result, so it
    # is still constant and the outer gradient is zero.
    @test gradient(x -> sum(abs2, gradient(sum, x .^ 2)[1]), pt)[1] == [0,0,0]
    @test gradient(x -> sum(abs2, gradient(sum, x .^ 3)[1]), pt)[1] == [0,0,0]
end
4364

65+
# First-order gradients of `foldl` / `foldr` over vectors and tuples.
@testset "foldl" begin
    vec4 = [1, 2, 3, 4]
    tup4 = (1, 2, 3, 4)

    # Vector input: the gradient is compared against a plain array.
    @test gradient(x -> foldl(*, x), vec4)[1] == [24.0, 12.0, 8.0, 6.0]
    @test gradient(x -> foldl(*, x; init=5), vec4)[1] == [120.0, 60.0, 40.0, 30.0]
    @test gradient(x -> foldr(*, x), vec4)[1] == [24, 12, 8, 6]

    # Tuple input: the gradient is expected as a `Tangent` over `NTuple{4,Int}`.
    # Where that comparison is marked broken, the values are still checked
    # after converting the result with `Tuple`.
    @test gradient(x -> foldl(*, x), tup4)[1] == Tangent{NTuple{4,Int}}(24.0, 12.0, 8.0, 6.0)
    @test_broken gradient(x -> foldl(*, x; init=5), tup4)[1] == Tangent{NTuple{4,Int}}(120.0, 60.0, 40.0, 30.0)  # does not return a Tangent
    @test Tuple(only(gradient(x -> foldl(*, x; init=5), tup4))) == (120.0, 60.0, 40.0, 30.0)
    @test_broken gradient(x -> foldr(*, x), tup4)[1] == Tangent{NTuple{4,Int}}(24, 12, 8, 6)
    @test Tuple(only(gradient(x -> foldr(*, x), tup4))) == (24, 12, 8, 6)
end
4478

4579

4680
#####
4781
##### LinearAlgebra/dense.jl
4882
#####
4983

5084

85+
# First- and second-order gradients through `LinearAlgebra.dot`.
@testset "dot" begin
    x0 = [4, 5, 6]

    # d/dx dot(x, w)^2 = 2 * dot(x, w) * w, with dot([4,5,6], [1,2,3]) = 32.
    @test gradient(x -> dot(x, [1,2,3])^2, x0)[1] == [64,128,192]

    # Second order currently fails.
    # MethodError: no method matching +(::Tuple{Tangent{ChainRules.var… (message truncated in the source)
    @test_broken gradient(x -> sum(gradient(x -> dot(x, [1,2,3])^2, x)[1]), x0)[1] == [12,24,36]
end
5191

5292

5393
#####

test/diffractor_01.jl

Lines changed: 20 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,19 @@
1-
# The rest of this file is unchanged, except the very end,
2-
# but IMO we should move these tests to a new file.
1+
# This file has tests written specifically for Diffractor v0.1,
2+
# which were in runtests.jl before PR 73 moved them all.
33

4-
# Loading Diffractor: var"'" globally will break many tests above, which use it for adjoint.
4+
using Test
5+
6+
using Diffractor
7+
using Diffractor: ∂⃖, DiffractorRuleConfig
58

6-
using Diffractor: var"'", ∂⃖, DiffractorRuleConfig
79
using ChainRules
810
using ChainRulesCore
911
using ChainRulesCore: ZeroTangent, NoTangent, frule_via_ad, rrule_via_ad
1012
using Symbolics
11-
using LinearAlgebra
1213

14+
using LinearAlgebra
1315

16+
# Loading Diffractor: var"'" globally will break many tests above, which use it for adjoint.
1417
const fwd = Diffractor.PrimeDerivativeFwd
1518
const bwd = Diffractor.PrimeDerivativeBack
1619

@@ -48,8 +51,10 @@ ChainRules.rrule(::typeof(my_tuple), args...) = args, Δ->Core.tuple(NoTangent()
4851
# NOTE(review): the operands of this expression were restored from a garbled
# copy (symbols and opening parentheses were missing); the parentheses now
# balance, but confirm the exact terms against the repository.
@test isequal(simplify(x8), simplify((η + (α*ζ) + (β*ϵ) + (δ*(γ + (α*β))))*exp(ω)))
4952

5053
# Minimal 2-nd order forward smoke test
51-
@test Diffractor.∂☆{2}()(Diffractor.ZeroBundle{2}(sin),
52-
Diffractor.TangentBundle{2}(1.0, (1.0, 1.0, 0.0)))[Diffractor.CanonicalTangentIndex(1)] == sin'(1.0)
54+
# Bind the postfix `'` operator locally to Diffractor's reverse-mode prime
# derivative, so that `sin'(1.0)` on the right-hand side uses it without
# changing the meaning of `'` for the rest of the file.
let var"'" = Diffractor.PrimeDerivativeBack
    @test getindex(
        Diffractor.∂☆{2}()(
            Diffractor.ZeroBundle{2}(sin),
            Diffractor.TangentBundle{2}(1.0, (1.0, 1.0, 0.0)),
        ),
        Diffractor.CanonicalTangentIndex(1),
    ) == sin'(1.0)
end
5358

5459
function simple_control_flow(b, x)
5560
if b
@@ -269,7 +274,7 @@ end
269274
@testset "broadcast, 2nd order" begin
270275
@test gradient(x -> gradient(y -> sum(y .* y), x)[1] |> sum, [1,2,3.0])[1] == [2,2,2] # calls "split broadcasting generic" with f = unthunk
271276
@test gradient(x -> gradient(y -> sum(y .* x), x)[1].^3 |> sum, [1,2,3.0])[1] == [3,12,27]
272-
@test_broken gradient(x -> gradient(y -> sum(y .* 2 .* y'), x)[1] |> sum, [1,2,3.0])[1] == [12, 12, 12] # Control flow support not fully implemented yet for higher-order
277+
@test gradient(x -> gradient(y -> sum(y .* 2 .* y'), x)[1] |> sum, [1,2,3.0])[1] == [12, 12, 12]
273278

274279
@test_broken gradient(x -> sum(gradient(x -> sum(x .^ 2 .+ x'), x)[1]), [1,2,3.0])[1] == [6,6,6] # BoundsError: attempt to access 18-element Vector{Core.Compiler.BasicBlock} at index [0]
275280
@test_broken gradient(x -> sum(gradient(x -> sum((x .+ 1) .* x .- x), x)[1]), [1,2,3.0])[1] == [2,2,2]
@@ -283,3 +288,10 @@ end
283288
@test_broken gradient(z -> gradient(x -> sum((y -> (x^2*y)).([1,2,3])), z)[1], 5.0) == (12.0,)
284289
end
285290

291+
# Issue 67, due to https://github.com/JuliaDiff/ChainRulesCore.jl/pull/495
# Differentiates a composed function (`identity∘sqrt`); d/dx sqrt(x) at 4.0 is 0.25.
@test gradient(identity∘sqrt, 4.0) == (0.25,)

# Issue #70 - Complex & getproperty
# Field access `x.re` on a Complex is still broken, while the equivalent
# expression written with `real(x)` passes.
@test_broken gradient(x -> x.re, 2+3im)[1] == 1
@test_broken gradient(x -> abs2(x * x.re), 4+5im)[1] == 456 + 160im
@test gradient(x -> abs2(x * real(x)), 4+5im)[1] == 456 + 160im

0 commit comments

Comments
 (0)