# This file has integration tests for some rules defined in ChainRules.jl,
# especially those which aim to support higher derivatives, as properly
- # testing those is difficult.
+ # testing those is difficult. Organised according to the files in CR.jl.
+
+ using Diffractor, ForwardDiff, ChainRulesCore
+ using Test, LinearAlgebra
+
+ using Test: Threw, eval_test
- using Diffractor, ChainRulesCore, ForwardDiff
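
# A minimal sketch of the pattern used below (assuming only the packages imported above):
# nesting `gradient` exercises second derivatives. For f(x) = sum(abs2, x), the inner call
# gives gradient(f, x)[1] == 2 .* x, so sum(abs2, gradient(f, x)[1]) == 4 * sum(abs2, x),
# whose gradient is 8 .* x; hence the "sum" testset expects [8, 16, 24] at x = [1, 2, 3].
# ForwardDiff gives an independent cross-check of inner first derivatives,
# e.g. ForwardDiff.derivative(x -> x^2, 3.0) == 6.0.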

#####
##### Base/array.jl
@@ -13,7 +17,6 @@ using Diffractor, ChainRulesCore, ForwardDiff
-
#####
##### Base/arraymath.jl
#####
@@ -33,21 +36,58 @@ using Diffractor, ChainRulesCore, ForwardDiff
##### Base/indexing.jl
#####

+ @testset " getindex, first" begin
40
+ @test_broken gradient (x -> sum (abs2, gradient (first, x)[1 ]), [1 ,2 ,3 ])[1 ] == [0 , 0 , 0 ] # MethodError: no method matching +(::Tuple{ZeroTangent, ZeroTangent}, ::Tuple{ZeroTangent, ZeroTangent})
41
+ @test_broken gradient (x -> sum (abs2, gradient (sqrt∘ first, x)[1 ]), [1 ,2 ,3 ])[1 ] ≈ [- 0.25 , 0 , 0 ] # error() in perform_optic_transform(ff::Type{Diffractor.∂⃖recurse{2}}, args::Any)
42
+ @test gradient (x -> sum (abs2, gradient (x -> x[1 ]^ 2 , x)[1 ]), [1 ,2 ,3 ])[1 ] == [8 , 0 , 0 ]
43
+ @test_broken gradient (x -> sum (abs2, gradient (x -> sum (x[1 : 2 ])^ 2 , x)[1 ]), [1 ,2 ,3 ])[1 ] == [48 , 0 , 0 ] # MethodError: no method matching +(::Tuple{ZeroTangent, ZeroTangent}, ::Tuple{ZeroTangent, ZeroTangent})
44
+ end
-
+ @testset " eachcol etc" begin
47
+ @test gradient (m -> sum (prod, eachcol (m)), [1 2 3 ; 4 5 6 ])[1 ] == [4 5 6 ; 1 2 3 ]
48
+ @test gradient (m -> sum (first, eachcol (m)), [1 2 3 ; 4 5 6 ])[1 ] == [1 1 1 ; 0 0 0 ]
49
+ @test gradient (m -> sum (first (eachcol (m))), [1 2 3 ; 4 5 6 ])[1 ] == [1 0 0 ; 1 0 0 ]
50
+ @test_skip gradient (x -> sum (sin, gradient (m -> sum (first (eachcol (m))), x)[1 ]), [1 2 3 ; 4 5 6 ])[1 ] # MethodError: no method matching one(::Base.OneTo{Int64}), unzip_broadcast, split_bc_forwards
51
+ end

#####
##### Base/mapreduce.jl
#####

+ @testset " sum" begin
58
+ @test gradient (x -> sum (abs2, gradient (sum, x)[1 ]), [1 ,2 ,3 ])[1 ] == [0 ,0 ,0 ]
59
+ @test gradient (x -> sum (abs2, gradient (x -> sum (abs2, x), x)[1 ]), [1 ,2 ,3 ])[1 ] == [8 ,16 ,24 ]
60
+
61
+ @test gradient (x -> sum (abs2, gradient (sum, x .^ 2 )[1 ]), [1 ,2 ,3 ])[1 ] == [0 ,0 ,0 ]
62
+ @test gradient (x -> sum (abs2, gradient (sum, x .^ 3 )[1 ]), [1 ,2 ,3 ])[1 ] == [0 ,0 ,0 ]
63
+ end
+ @testset " foldl" begin
66
+
67
+ @test gradient (x -> foldl (* , x), [1 ,2 ,3 ,4 ])[1 ] == [24.0 , 12.0 , 8.0 , 6.0 ]
68
+ @test gradient (x -> foldl (* , x; init= 5 ), [1 ,2 ,3 ,4 ])[1 ] == [120.0 , 60.0 , 40.0 , 30.0 ]
69
+ @test gradient (x -> foldr (* , x), [1 ,2 ,3 ,4 ])[1 ] == [24 , 12 , 8 , 6 ]
70
+
71
+ @test gradient (x -> foldl (* , x), (1 ,2 ,3 ,4 ))[1 ] == Tangent {NTuple{4,Int}} (24.0 , 12.0 , 8.0 , 6.0 )
72
+ @test_broken gradient (x -> foldl (* , x; init= 5 ), (1 ,2 ,3 ,4 ))[1 ] == Tangent {NTuple{4,Int}} (120.0 , 60.0 , 40.0 , 30.0 ) # does not return a Tangent
73
+ @test gradient (x -> foldl (* , x; init= 5 ), (1 ,2 ,3 ,4 )) |> only |> Tuple == (120.0 , 60.0 , 40.0 , 30.0 )
74
+ @test_broken gradient (x -> foldr (* , x), (1 ,2 ,3 ,4 ))[1 ] == Tangent {NTuple{4,Int}} (24 , 12 , 8 , 6 )
75
+ @test gradient (x -> foldr (* , x), (1 ,2 ,3 ,4 )) |> only |> Tuple == (24 , 12 , 8 , 6 )
76
+
77
+ end

#####
##### LinearAlgebra/dense.jl
#####

+ @testset " dot" begin
86
+
87
+ @test gradient (x -> dot (x, [1 ,2 ,3 ])^ 2 , [4 ,5 ,6 ])[1 ] == [64 ,128 ,192 ]
88
+ @test_broken gradient (x -> sum (gradient (x -> dot (x, [1 ,2 ,3 ])^ 2 , x)[1 ]), [4 ,5 ,6 ])[1 ] == [12 ,24 ,36 ] # MethodError: no method matching +(::Tuple{Tangent{ChainRules.var
89
+
90
+ end

#####