Commit 682958a

[AutoDiff] Add Differentiable.zeroTangentVectorInitializer. (swiftlang#28416)
The `Differentiable.zeroTangentVectorInitializer` requirement will enable correct, efficient reverse-mode differentiation of struct property accesses.

A `zeroTangentVectorInitializer` closure is more efficient than a `zeroTangentVector` computed property, which would always capture `self`.

Add `Differentiable.zeroTangentVector` default instance property, which returns `self.zeroTangentVectorInitializer()`.

Todos:
- Implement derived conformances for `zeroTangentVectorInitializer`.
- Implement differentiation transform support for `zeroTangentVector` and struct/tuple projection instructions.

1 parent a946ec9 commit 682958a

1 file changed, +50 -17 lines changed

stdlib/public/Differentiation/Differentiable.swift

@@ -36,17 +36,39 @@ public protocol Differentiable {
   mutating func move(along direction: TangentVector)
 
   // SWIFT_ENABLE_TENSORFLOW
-  /// A tangent vector such that `move(along: zeroTangentVector)` will not
-  /// modify `self`.
-  /// - Note: `zeroTangentVector` can be `TangentVector.zero` in most cases,
-  ///   but types whose tangent vectors depend on instance properties of `self`
-  ///   need to provide a different implementation. For example, the tangent
-  ///   vector of an `Array` depends on the array's `count`.
-  @available(*, deprecated, message: """
-      `zeroTangentVector` derivation has not been implemented; do not use \
-      this property
-      """)
-  var zeroTangentVector: TangentVector { get }
+  /// A closure that produces a zero tangent vector, capturing minimal
+  /// necessary information from `self`.
+  ///
+  /// `move(along: zeroTangentVectorInitializer())` should not modify
+  /// `self`.
+  ///
+  /// In some cases, the zero tangent vector of `self` is equal to
+  /// `TangentVector.zero`. In other cases, the zero tangent vector depends on
+  /// information in `self`, such as shape for an n-dimensional array type.
+  /// For differentiable programming, it is more memory-efficient to define a
+  /// custom `zeroTangentVectorInitializer` property which returns a closure
+  /// that captures and uses only the necessary information to create a zero
+  /// tangent vector. For example:
+  ///
+  ///     struct Vector {
+  ///         var scalars: [Float]
+  ///         var count: Int { scalars.count }
+  ///         init(scalars: [Float]) { ... }
+  ///         init(repeating repeatedElement: Float, count: Int) { ... }
+  ///     }
+  ///
+  ///     extension Vector: AdditiveArithmetic { ... }
+  ///
+  ///     extension Vector: Differentiable {
+  ///         typealias TangentVector = Vector
+  ///
+  ///         @noDerivative
+  ///         var zeroTangentVectorInitializer: () -> TangentVector {
+  ///             let count = self.count
+  ///             return { TangentVector(repeating: 0, count: count) }
+  ///         }
+  ///     }
+  var zeroTangentVectorInitializer: () -> TangentVector { get }
   // SWIFT_ENABLE_TENSORFLOW END
 }
 
@@ -59,12 +81,23 @@ public extension Differentiable where TangentVector == Self {
 
 // SWIFT_ENABLE_TENSORFLOW
 public extension Differentiable {
-  // This is a temporary solution that allows us to add `zeroTangentVector`
-  // without implementing derived conformances. This property is marked
-  // unavailable because it will produce incorrect results when tangent vectors
-  // depend on instance properties of `self`.
-  // FIXME: Implement derived conformance and remove this default
+  // This is a temporary solution enabling the addition of
+  // `zeroTangentVectorInitializer` without implementing derived conformances.
+  // This property will produce incorrect results when tangent vectors depend
+  // on instance-specific information from `self`.
+  // FIXME: Implement derived conformances and remove this default
   // implementation.
-  var zeroTangentVector: TangentVector { .zero }
+  @available(*, deprecated, message: """
+      `zeroTangentVectorInitializer` derivation has not been implemented; this \
+      default implementation is not correct when tangent vectors depend on \
+      instance-specific information from `self` and should not be used
+      """)
+  var zeroTangentVectorInitializer: () -> TangentVector {
+    { TangentVector.zero }
+  }
+
+  /// A tangent vector initialized using `zeroTangentVectorInitializer`.
+  /// `move(along: zeroTangentVector)` should not modify `self`.
+  var zeroTangentVector: TangentVector { zeroTangentVectorInitializer() }
 }
 // SWIFT_ENABLE_TENSORFLOW END
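
The capture pattern from the doc comment can be tried without an AutoDiff-enabled toolchain. The following is a minimal sketch, not the stdlib implementation: `Vector` here is a hypothetical stand-in that does not conform to `Differentiable`, its `move(along:)` is assumed to be plain elementwise addition, and `naiveZero` is an assumed shape-less value standing in for `TangentVector.zero`.

```swift
// Hypothetical stand-in for the `Vector` in the doc comment. It does not
// conform to `Differentiable`, so it compiles with a stock Swift toolchain.
struct Vector: Equatable {
  var scalars: [Float]
  var count: Int { scalars.count }

  init(scalars: [Float]) { self.scalars = scalars }
  init(repeating repeatedElement: Float, count: Int) {
    self.scalars = Array(repeating: repeatedElement, count: count)
  }

  // Mirrors the new requirement: copy only `count` out of `self`, so the
  // returned closure does not capture `self` or its `scalars` storage.
  var zeroTangentVectorInitializer: () -> Vector {
    let count = self.count
    return { Vector(repeating: 0, count: count) }
  }

  // Mirrors the default `zeroTangentVector` property added by this commit.
  var zeroTangentVector: Vector { zeroTangentVectorInitializer() }

  // Assumed semantics: elementwise addition over vectors of equal count.
  mutating func move(along direction: Vector) {
    precondition(direction.count == count, "count mismatch")
    for i in scalars.indices { scalars[i] += direction.scalars[i] }
  }
}

var v = Vector(scalars: [1, 2, 3])
let before = v
let makeZero = v.zeroTangentVectorInitializer  // closure captures only an Int
v.move(along: makeZero())                      // leaves `v` unchanged

// Assumed shape-less zero: its count does not match `v`, which is why a
// shape-independent default is incorrect for types like this one.
let naiveZero = Vector(scalars: [])
```

Because the closure holds only an `Int`, it can be stored after `v` itself is released, which is the memory saving the commit message describes. Passing `naiveZero` to `move(along:)` would hit the precondition in this sketch.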
