
Commit 9259e4a: fix and add tests

1 parent 54cdd12

2 files changed: +13 -8 lines


src/layers/normalise.jl (+8 -3)
@@ -2,9 +2,14 @@ istraining() = false
 
 ChainRulesCore.rrule(::typeof(istraining)) = true, _ -> (NoTangent(),)
 
-_isactive(m) = isnothing(m.active) ? istraining() : Bool(m.active)
+_isactive(m) = Bool(something(m.active, istraining()))
 
-ChainRulesCore.@non_differentiable _isactive(::Any)
+# Avoids instabilities from differentiating through getproperty(m, :active)
+function ChainRulesCore.rrule(::typeof(_isactive), m)
+  training, _ = rrule(istraining)
+  _isactive_pullback(_) = (NoTangent(), NoTangent())
+  return Bool(something(m.active, training)), _isactive_pullback
+end
 
 _dropout_shape(s, ::Colon) = size(s)
 _dropout_shape(s, dims) = tuple((i ∉ dims ? 1 : si for (i, si) ∈ enumerate(size(s)))...)
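
For context: `something(a, b)` returns its first non-`nothing` argument, so the new one-liner reproduces the old `isnothing` ternary, and the handwritten `rrule` replaces the blanket `@non_differentiable` mark so that AD never traces through `getproperty(m, :active)`. A minimal self-contained sketch of the same pattern, with a hypothetical `Demo` layer standing in for Flux's norm layers:

    using ChainRulesCore

    # `active == nothing` means "defer to istraining()"; a Bool pins the mode.
    struct Demo
        active::Union{Bool, Nothing}
    end

    istraining() = false
    ChainRulesCore.rrule(::typeof(istraining)) = true, _ -> (NoTangent(),)

    _isactive(m) = Bool(something(m.active, istraining()))

    # Same shape as the rrule above: resolve the flag eagerly and report
    # NoTangent() for both the function and the layer argument.
    function ChainRulesCore.rrule(::typeof(_isactive), m)
        training, _ = rrule(istraining)
        _isactive_pullback(_) = (NoTangent(), NoTangent())
        return Bool(something(m.active, training)), _isactive_pullback
    end

    _isactive(Demo(nothing))  # false outside AD; the rrule yields true
    _isactive(Demo(true))     # true either way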
@@ -59,7 +64,7 @@ end
 
 function (pb::DropoutPullback)(dy)
   dx = pb.project(_apply_mask(dy, pb.mask))
-  return (NoTangent(), NoTangent(), dx, NoTangent())
+  return (NoTangent(), NoTangent(), dx, NoTangent(), NoTangent(), NoTangent())
 end
 
 _apply_mask(x, ::Nothing) = x
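
The longer tuple follows the ChainRules convention that a pullback returns one tangent per slot of the `rrule` call, the differentiated function first and then each primal argument in order; the internal dropout call evidently gained two extra non-differentiable arguments, so the pullback must now report `NoTangent()` for them as well. A sketch of that convention, using a hypothetical three-argument kernel rather than Flux's actual dropout internals:

    using ChainRulesCore

    # Hypothetical stand-in: one array input that carries gradient, plus
    # arguments treated here as non-differentiable configuration.
    scale_by(x, factor, flag) = flag ? x .* factor : x

    function ChainRulesCore.rrule(::typeof(scale_by), x, factor, flag)
        y = scale_by(x, factor, flag)
        # One tangent per slot, in order: scale_by, x, factor, flag.
        # Widening the forward signature means widening this tuple to match,
        # which is what the DropoutPullback change above does.
        scale_by_pullback(dy) =
            (NoTangent(), flag ? dy .* factor : dy, NoTangent(), NoTangent())
        return y, scale_by_pullback
    end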

test/layers/normalisation.jl (+5 -5)
@@ -73,10 +73,10 @@ evalwgrad(f, x...) = pullback(f, x...)[1]
       @test cpu(m).rng === only(values(rng_kwargs))
     end
   end
-
+
  for active in (true, false)
    m = Dropout(0.5, :, active)
-   @inferred _, back = pullback(m, rand(10)) # _, DropoutPullback{Array{Float64}}
+   _, back = @inferred pullback(m, rand(10)) # _, DropoutPullback{Array{Float64}}
    @inferred back(ones(10)) # Array{Float64}
  end
 end
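
The move matters because `Test.@inferred` must wrap a single call: applied to the assignment, it never checks `pullback` itself. Since `@inferred` returns the call's value when inference succeeds, the result can still be destructured. A small illustration with a made-up function:

    using Test

    f(x) = (sum(x), x .+ 1)

    # Wrap the call, then destructure the returned tuple as usual.
    total, shifted = @inferred f(rand(10))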
@@ -353,9 +353,9 @@ end
   x = rand(2)
   m = LayerNorm(2, tanh)
   @test m(x) ≈ tanh.(Flux.normalise(x, dims=1))
-  @inferred _, back = pullback(sum∘m, x)
-  @inferred back(1.0)
-
+  _, back = @inferred pullback(|>, x, m)
+  # TODO needs https://github.com/FluxML/Zygote.jl/pull/1248
+  # @inferred back(1.0)
 
   x = rand(2,3,4,5)
   @test LayerNorm((2,3))(x) ≈ Flux.normalise(x, dims=(1,2))
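
Note that `|>` is an ordinary Julia function, so `pullback(|>, x, m)` differentiates the application `x |> m` with `m` passed as a plain argument, giving `@inferred` a concrete call signature to check. A small sketch with Zygote, which these tests already load:

    using Zygote: pullback

    y, back = pullback(|>, [1.0, 2.0], sum)  # y == 3.0
    dx, _ = back(1.0)                        # dx == [1.0, 1.0]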
