@@ -36,17 +36,39 @@ public protocol Differentiable {
36
36
mutating func move( along direction: TangentVector )
37
37
38
38
// SWIFT_ENABLE_TENSORFLOW
39
- /// A tangent vector such that `move(along: zeroTangentVector)` will not
40
- /// modify `self`.
41
- /// - Note: `zeroTangentVector` can be `TangentVector.zero` in most cases,
42
- /// but types whose tangent vectors depend on instance properties of `self`
43
- /// need to provide a different implementation. For example, the tangent
44
- /// vector of an `Array` depends on the array's `count`.
45
- @available ( * , deprecated, message: """
46
- `zeroTangentVector` derivation has not been implemented; do not use \
47
- this property
48
- """ )
49
- var zeroTangentVector : TangentVector { get }
39
+ /// A closure that produces a zero tangent vector, capturing minimal
40
+ /// necessary information from `self`.
41
+ ///
42
+ /// `move(along: zeroTangentVectorInitializer())` should not modify
43
+ /// `self`.
44
+ ///
45
+ /// In some cases, the zero tangent vector of `self` is equal to
46
+ /// `TangentVector.zero`. In other cases, the zero tangent vector depends on
47
+ /// information in `self`, such as shape for an n-dimensional array type.
48
+ /// For differentiable programming, it is more memory-efficient to define a
49
+ /// custom `zeroTangentVectorInitializer` property which returns a closure
50
+ /// that captures and uses only the necessary information to create a zero
51
+ /// tangent vector. For example:
52
+ ///
53
+ /// struct Vector {
54
+ /// var scalars: [Float]
55
+ /// var count: Int { scalars.count }
56
+ /// init(scalars: [Float]) { ... }
57
+ /// init(repeating repeatedElement: Float, count: Int) { ... }
58
+ /// }
59
+ ///
60
+ /// extension Vector: AdditiveArithmetic { ... }
61
+ ///
62
+ /// extension Vector: Differentiable {
63
+ /// typealias TangentVector = Vector
64
+ ///
65
+ /// @noDerivative
66
+ /// var zeroTangentVectorInitializer: () -> TangentVector {
67
+ /// let count = self.count
68
+ /// return { TangentVector(repeating: 0, count: count) }
69
+ /// }
70
+ /// }
71
+ var zeroTangentVectorInitializer : ( ) -> TangentVector { get }
50
72
// SWIFT_ENABLE_TENSORFLOW END
51
73
}
52
74
@@ -59,12 +81,23 @@ public extension Differentiable where TangentVector == Self {
59
81
60
82
// SWIFT_ENABLE_TENSORFLOW
61
83
public extension Differentiable {
62
- // This is a temporary solution that allows us to add `zeroTangentVector`
63
- // without implementing derived conformances. This property is marked
64
- // unavailable because it will produce incorrect results when tangent vectors
65
- // depend on instance properties of `self`.
66
- // FIXME: Implement derived conformance and remove this default
84
+ // This is a temporary solution enabling the addition of
85
+ // `zeroTangentVectorInitializer` without implementing derived conformances.
86
+ // This property will produce incorrect results when tangent vectors depend
87
+ // on instance-specific information from `self`.
88
+ // FIXME: Implement derived conformances and remove this default
67
89
// implementation.
68
- var zeroTangentVector : TangentVector { . zero }
90
+ @available ( * , deprecated, message: """
91
+ `zeroTangentVectorInitializer` derivation has not been implemented; this \
92
+ default implementation is not correct when tangent vectors depend on \
93
+ instance-specific information from `self` and should not be used
94
+ """ )
95
+ var zeroTangentVectorInitializer : ( ) -> TangentVector {
96
+ { TangentVector . zero }
97
+ }
98
+
99
+ /// A tangent vector initialized using `zeroTangentVectorInitializer`.
100
+ /// `move(along: zeroTangentVector)` should not modify `self`.
101
+ var zeroTangentVector : TangentVector { zeroTangentVectorInitializer ( ) }
69
102
}
70
103
// SWIFT_ENABLE_TENSORFLOW END
0 commit comments