@@ -2,9 +2,14 @@ istraining() = false
2
2
3
3
# Reverse-mode rule for `istraining`: the primal value reported under this rule is
# `true`, and the pullback yields a single `NoTangent()` (for the function itself).
function ChainRulesCore.rrule(::typeof(istraining))
    istraining_pullback(_) = (NoTangent(),)
    return true, istraining_pullback
end
4
4
5
# Decide whether layer `m` is active: an explicit `m.active` setting wins, and
# `nothing` falls back to whatever `istraining()` reports. The result is always
# converted to `Bool`.
function _isactive(m)
    fallback = istraining()
    return Bool(something(m.active, fallback))
end
6
6
7
# Avoids instabilities from differentiating through getproperty(m, :active)
#
# The primal value mirrors `_isactive`, except that the fallback is taken from the
# `istraining` rrule rather than from a plain `istraining()` call. The pullback
# returns `NoTangent()` for both the function and `m`.
function ChainRulesCore.rrule(::typeof(_isactive), m)
    training = first(rrule(istraining))
    active = Bool(something(m.active, training))
    _isactive_pullback(_) = (NoTangent(), NoTangent())
    return active, _isactive_pullback
end
8
13
9
14
# With `dims = :` the dropout shape is simply the full size of `s`.
function _dropout_shape(s, ::Colon)
    return size(s)
end
10
15
# Build a shape tuple with one entry per dimension of `s`: dimensions listed in
# `dims` keep their size, every other dimension collapses to 1.
function _dropout_shape(s, dims)
    dimsizes = size(s)
    return ntuple(i -> i in dims ? dimsizes[i] : 1, length(dimsizes))
end
59
64
60
65
# Pullback for dropout: mask the incoming cotangent `dy`, project it with
# `pb.project`, and return it in the third slot. All other slots are `NoTangent()`.
# NOTE(review): the six-element tuple presumably matches the argument list of the
# corresponding rrule, which is not visible here — confirm against that signature.
function (pb::DropoutPullback)(dy)
    masked = _apply_mask(dy, pb.mask)
    dx = pb.project(masked)
    nt = NoTangent()
    return (nt, nt, dx, nt, nt, nt)
end
64
69
65
70
# With no mask (`nothing`), the input passes through unchanged.
function _apply_mask(x, ::Nothing)
    return x
end
0 commit comments