Commit 54cdd12

Improve type stability of LayerNorm and Dropout
1 parent 952c4a5 · commit 54cdd12

File tree

4 files changed: +88 -25 lines changed
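Context for the diffs below: "type stability" here means Julia's compiler can infer one concrete type for the forward pass and for the Zygote pullback. A minimal way to probe that, mirroring the tests added in this commit (a sketch only; it assumes Flux, Zygote and Test are loaded and uses the three-argument Dropout constructor exercised by the new tests):

using Flux, Test
using Zygote: pullback

m = Dropout(0.5, :, true)      # dropout forced active, dims = :
x = rand(10)

_, back = pullback(m, x)       # forward pass plus pullback closure
@inferred back(ones(10))       # per the new tests, should infer as Array{Float64}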

src/Flux.jl (+3 -1)

@@ -9,7 +9,9 @@ using MacroTools: @forward
 using MLUtils
 import Optimisers: Optimisers, trainable, destructure # before v0.13, Flux owned these functions
 
-using Zygote, ChainRulesCore
+using ChainRulesCore
+
+using Zygote
 using Zygote: Params, @adjoint, gradient, pullback, @nograd
 export gradient

src/layers/normalise.jl (+41 -17)

@@ -2,7 +2,9 @@ istraining() = false
 
 ChainRulesCore.rrule(::typeof(istraining)) = true, _ -> (NoTangent(),)
 
-_isactive(m) = isnothing(m.active) ? istraining() : m.active
+_isactive(m) = isnothing(m.active) ? istraining() : Bool(m.active)
+
+ChainRulesCore.@non_differentiable _isactive(::Any)
 
 _dropout_shape(s, ::Colon) = size(s)
 _dropout_shape(s, dims) = tuple((i ∉ dims ? 1 : si for (i, si) ∈ enumerate(size(s)))...)
@@ -31,26 +33,51 @@ automatically managed using the [`Dropout`](@ref) layer instead of the
 
 The [`Dropout`](@ref) layer is what you should use in most scenarios.
 """
-function dropout(rng, x, p; dims=:, active::Bool=true)
-  active || return x
-  y = dropout_mask(rng, x, p, dims=dims)
-  return x .* y
-end
+dropout(rng, x, p; dims=:, active::Bool=true) = _dropout(rng, x, p, dims, active)
 dropout(x, p; kwargs...) = dropout(rng_from_array(x), x, p; kwargs...)
 
-dropout_mask(rng::CUDA.RNG, x::CuArray, p; kwargs...) = _dropout_mask(rng, x, p; kwargs...)
-dropout_mask(rng, x::CuArray, p; kwargs...) =
-  throw(ArgumentError("x isa CuArray, but rng isa $(typeof(rng)). dropout_mask only support CUDA.RNG for CuArrays."))
-dropout_mask(rng, x, p; kwargs...) = _dropout_mask(rng, x, p; kwargs...)
-function _dropout_mask(rng, x, p; dims=:)
+# Internal function without kwargs to keep Zygote generated code type stable
+function _dropout(rng, x, p, dims, active)
+  mask = active ? dropout_mask(rng, x, p, dims) : nothing
+  return _apply_mask(x, mask)
+end
+
+function ChainRulesCore.rrule(::typeof(_dropout), rng, x, p, dims, active)
+  mask = active ? dropout_mask(rng, x, p, dims) : nothing
+  # Required because we don't always call dropout_mask
+  MT = Core.Compiler.return_type(dropout_mask, Tuple{typeof(rng),typeof(x),typeof(p),typeof(dims)})
+  project_x = ProjectTo(x)
+  return _apply_mask(x, mask), DropoutPullback{MT,typeof(project_x)}(mask, project_x)
+end
+
+# Also needed for type stability. Otherwise inference lifts the Union into a
+# Union{pullback{Nothing}, pullback{AbstractArray}}
+struct DropoutPullback{M<:AbstractArray,P<:ProjectTo{AbstractArray}}
+  mask::Union{Nothing,M}
+  project::P
+end
+
+function (pb::DropoutPullback)(dy)
+  dx = pb.project(_apply_mask(dy, pb.mask))
+  return (NoTangent(), NoTangent(), dx, NoTangent())
+end
+
+_apply_mask(x, ::Nothing) = x
+_apply_mask(x, mask) = x .* mask
+
+dropout_mask(rng::CUDA.RNG, x::CuArray, p, dims) = _dropout_mask(rng, x, p, dims)
+dropout_mask(rng, x::CuArray, p, dims) =
+  throw(ArgumentError("x isa CuArray, but rng isa $(typeof(rng)). dropout_mask only supports CUDA.RNG for CuArrays."))
+dropout_mask(rng, x, p, dims) = _dropout_mask(rng, x, p, dims)
+function _dropout_mask(rng, x, p, dims)
   realfptype = float(real(eltype(x)))
   y = rand!(rng, similar(x, realfptype, _dropout_shape(x, dims)))
   y .= _dropout_kernel.(y, p, 1 - p)
   return y
 end
 
 # TODO move this to NNlib
-ChainRulesCore.@non_differentiable dropout_mask(::Any, ::Any, ::Any)
+ChainRulesCore.@non_differentiable dropout_mask(::Any, ::Any, ::Any, ::Any)
 
 """
     Dropout(p; dims=:, rng = rng_from_array())
@@ -82,10 +109,7 @@ end
 @functor Dropout
 trainable(a::Dropout) = (;)
 
-function (a::Dropout)(x)
-  _isactive(a) || return x
-  return dropout(a.rng, x, a.p; dims=a.dims, active=true)
-end
+(a::Dropout)(x) = _dropout(a.rng, x, a.p, a.dims, _isactive(a))
 
 testmode!(m::Dropout, mode=true) =
   (m.active = (isnothing(mode) || mode == :auto) ? nothing : !mode; m)
@@ -172,7 +196,7 @@ LayerNorm(size_act...; kw...) = LayerNorm(Int.(size_act[1:end-1]), size_act[end]
 
 @functor LayerNorm
 
-(a::LayerNorm)(x) = a.diag(normalise(x, dims=1:length(a.size), ϵ=a.ϵ))
+(a::LayerNorm)(x) = a.diag(_normalize(x, 1:length(a.size), a.ϵ))
 
 function Base.show(io::IO, l::LayerNorm)
   print(io, "LayerNorm(", join(l.size, ", "))
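The comments in the hunk above describe the key trick: keep the optional mask as a small `Union{Nothing, M}` field inside one concrete pullback struct, so inference sees a single pullback type instead of a union of two different closures. A stripped-down sketch of that pattern, with hypothetical names that are not Flux API:

# Hypothetical, minimal version of the DropoutPullback pattern above.
struct MaskedPullback{M<:AbstractArray}
  mask::Union{Nothing,M}   # small union: either no mask, or one concrete array type
end

apply_mask(dy, ::Nothing) = dy       # inactive: pass the gradient through
apply_mask(dy, mask) = dy .* mask    # active: rescale by the stored mask

(pb::MaskedPullback)(dy) = apply_mask(dy, pb.mask)

# Both branches produce the *same* pullback type, so code that captures it
# stays type stable:
pb_active   = MaskedPullback{Vector{Float64}}(rand(3))
pb_inactive = MaskedPullback{Vector{Float64}}(nothing)

In the real rrule, `Core.Compiler.return_type` supplies the concrete mask type even on the branch where `dropout_mask` is never called, which is what lets the struct be fully parameterised in both cases.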

src/layers/stateless.jl (+32 -6)

@@ -26,15 +26,41 @@ function flatten(x::AbstractArray)
   return reshape(x, :, size(x)[end])
 end
 
+# Utils for LayerNorm internals.
+# Most of these are required for better performance and type stability under AD.
+# In an ideal world, we'd just have normalise.
+
+function _mean_std(x::AbstractArray, dims)
+  μ = mean(x, dims=dims)
+  σ = std(x, dims=dims, mean=μ, corrected=false)
+  return μ, σ
+end
+
+function ChainRulesCore.rrule(::typeof(_mean_std), x::AbstractArray, dims)
+  μ, mean_pullback = ChainRulesCore.rrule(mean, x, dims=dims)
+  σ, std_pullback = ChainRulesCore.rrule(std, x, dims=dims, mean=μ, corrected=false)
+  function _mean_std_pullback((dμ, dσ))
+    dx = ChainRulesCore.add!!(std_pullback(dσ)[2], mean_pullback(dμ)[2])
+    return (NoTangent(), dx, NoTangent())
+  end
+
+  return (μ, σ), _mean_std_pullback
+end
+
+_zscore(x, μ, σ, ϵ) = (x - μ) / (σ + ϵ)
+
+# We don't define a rrule for the whole function because we want
+# AD to figure out the _zscore broadcast for us.
+function _normalize(x::AbstractArray, dims, ϵ)
+  μ, σ = _mean_std(x, dims)
+  return _zscore.(x, μ, σ, ϵ)
+end
+
 """
     normalise(x; dims=ndims(x), ϵ=1e-5)
 
 Normalise `x` to mean 0 and standard deviation 1 across the dimension(s) given by `dims`.
-Per default, `dims` is the last dimension. 
+Per default, `dims` is the last dimension.
 `ϵ` is a small additive factor added to the denominator for numerical stability.
 """
-@inline function normalise(x::AbstractArray; dims=ndims(x), ϵ=ofeltype(x, 1e-5))
-  μ = mean(x, dims=dims)
-  σ = std(x, dims=dims, mean=μ, corrected=false)
-  return @. (x - μ) / (σ + ϵ)
-end
+@inline normalise(x::AbstractArray; dims=ndims(x), ϵ=ofeltype(x, 1e-5)) = _normalize(x, dims, ϵ)
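The fused `_mean_std` rule adds the tangents coming back from the `mean` and `std` pullbacks into a single contribution for `x`, while `_zscore` is left for AD to handle as an ordinary broadcast. One way to sanity-check a refactor like this is to compare values and gradients against a naive reference; the snippet below is a sketch of that (assumes Flux and Zygote are loaded; the weights `w` exist only to make the gradient non-trivial):

using Flux, Zygote
using Statistics: mean, std

# Naive reference for the same z-score normalisation.
naive(x; dims=1, ϵ=1f-5) = (x .- mean(x, dims=dims)) ./ (std(x, dims=dims, corrected=false) .+ ϵ)

x = rand(Float32, 4, 3)
w = rand(Float32, 4, 3)

@assert Flux.normalise(x, dims=1) ≈ naive(x)

g_flux  = Zygote.gradient(x -> sum(w .* Flux.normalise(x, dims=1)), x)[1]
g_naive = Zygote.gradient(x -> sum(w .* naive(x)), x)[1]
@assert g_flux ≈ g_naive   # both routes should agree up to floating-point tolerance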

test/layers/normalisation.jl (+12 -1)

@@ -73,6 +73,12 @@ evalwgrad(f, x...) = pullback(f, x...)[1]
       @test cpu(m).rng === only(values(rng_kwargs))
     end
   end
+
+  for active in (true, false)
+    m = Dropout(0.5, :, active)
+    @inferred _, back = pullback(m, rand(10)) # _, DropoutPullback{Array{Float64}}
+    @inferred back(ones(10)) # Array{Float64}
+  end
 end
 
 @testset "AlphaDropout" begin
@@ -343,8 +349,13 @@ end
   @test LayerNorm(2)(x) ≈ Flux.normalise(x, dims=1)
   x = rand(2,3,4,5)
   @test LayerNorm(2)(x) ≈ Flux.normalise(x, dims=1)
+
   x = rand(2)
-  @test LayerNorm(2, tanh)(x) ≈ tanh.(Flux.normalise(x, dims=1))
+  m = LayerNorm(2, tanh)
+  @test m(x) ≈ tanh.(Flux.normalise(x, dims=1))
+  @inferred _, back = pullback(sum∘m, x)
+  @inferred back(1.0)
+
 
   x = rand(2,3,4,5)
   @test LayerNorm((2,3))(x) ≈ Flux.normalise(x, dims=(1,2))
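`Test.@inferred` is what turns these into regression tests for inference: it errors when the runtime type of a result differs from the type the compiler infers. A tiny illustration, unrelated to Flux and using hypothetical functions:

using Test

stable(x::Int) = x + 1                          # inferred as Int
unstable(x::Int) = x > 0 ? x + 1 : "negative"   # inferred as Union{Int, String}

@inferred stable(1)                                 # passes
@test_throws ErrorException @inferred unstable(1)   # @inferred raises an error here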
