13
13
// limitations under the License.
14
14
15
15
// TODO: Re-enable this for the stock toolchain when it can be realigned with VectorProtocol.
16
- #if !TENSORFLOW_USE_STANDARD_TOOLCHAIN
17
16
import TensorFlow
18
17
import _Differentiation
19
18
@@ -41,27 +40,27 @@ fileprivate func mustOverride(function: StaticString = #function, file: StaticSt
41
40
/// - Input: the input type of the underlying layar
42
41
/// - Output: the output type of the underlying layer
43
42
/// - Scalar: the scalar type of the underlying tangent vector
44
- internal class AnyLayerBox < Input: Differentiable , Output: Differentiable , Scalar : FloatingPoint & ElementaryFunctions > {
43
+ internal class AnyLayerBox < Input: Differentiable , Output: Differentiable > {
45
44
/// The underlying layer, type-erased to `Any`.
46
45
var typeErasedBase : Any {
47
46
mustOverride ( )
48
47
}
49
48
50
49
/// Returns the underlying layer unboxed to the given type, if possible.
51
50
func unboxed< U: Layer > ( to type: U . Type ) -> U ?
52
- where U. TangentVector. VectorSpaceScalar == Scalar {
51
+ where U. TangentVector. VectorSpaceScalar == Float {
53
52
mustOverride ( )
54
53
}
55
54
56
55
// `Differentiable` requirements.
57
56
/// Moves `self` along the given direction. In Riemannian geometry, this is equivalent to exponential map, which moves `self` on the geodesic surface along the given tangent vector.
58
- func _move( along direction: AnyLayerTangentVector < Scalar > ) {
57
+ func _move( along direction: AnyLayerTangentVector ) {
59
58
mustOverride ( )
60
59
}
61
60
62
61
// `EuclideanDifferentiable` requirements.
63
62
/// The differentiable vector component of `self`.
64
- var _differentiableVectorView : AnyLayerTangentVector < Scalar > {
63
+ var _differentiableVectorView : AnyLayerTangentVector {
65
64
mustOverride ( )
66
65
}
67
66
@@ -72,7 +71,7 @@ internal class AnyLayerBox<Input: Differentiable, Output: Differentiable, Scalar
72
71
}
73
72
74
73
func _vjpCallAsFunction( _ input: Input ) ->
75
- ( value: Output , pullback: ( Output . TangentVector ) -> ( AnyLayerTangentVector < Scalar > , Input . TangentVector ) ) {
74
+ ( value: Output , pullback: ( Output . TangentVector ) -> ( AnyLayerTangentVector , Input . TangentVector ) ) {
76
75
mustOverride ( )
77
76
}
78
77
@@ -84,14 +83,14 @@ internal class AnyLayerBox<Input: Differentiable, Output: Differentiable, Scalar
84
83
}
85
84
86
85
/// Creates a new box storing a copy of the underlying layer, used to preserve value semantics.
87
- func duplicate( ) -> AnyLayerBox < Input , Output , Scalar > {
86
+ func duplicate( ) -> AnyLayerBox < Input , Output > {
88
87
mustOverride ( )
89
88
}
90
89
}
91
90
92
91
/// A concrete implementation of the type-erased layer wrapper that forwards to an underlying layer.
93
- internal class ConcreteLayerBox < Underlying: Layer > : AnyLayerBox < Underlying . Input , Underlying . Output , Underlying . TangentVector . VectorSpaceScalar >
94
- where Underlying. TangentVector. VectorSpaceScalar: FloatingPoint & ElementaryFunctions {
92
+ internal class ConcreteLayerBox < Underlying: Layer > : AnyLayerBox < Underlying . Input , Underlying . Output >
93
+ where Underlying. TangentVector. VectorSpaceScalar == Float {
95
94
/// The underlying layer.
96
95
var underlying : Underlying
97
96
@@ -107,12 +106,12 @@ where Underlying.TangentVector.VectorSpaceScalar: FloatingPoint & ElementaryFunc
107
106
108
107
/// Returns the underlying layer unboxed to the given type, if possible.
109
108
override func unboxed< U: Layer > ( to type: U . Type ) -> U ?
110
- where U. TangentVector. VectorSpaceScalar == Underlying . TangentVector . VectorSpaceScalar {
109
+ where U. TangentVector. VectorSpaceScalar == Float {
111
110
return ( self as? ConcreteLayerBox < U > ) ? . underlying
112
111
}
113
112
114
113
// `Differentiable` requirements.
115
- override func _move( along direction: AnyLayerTangentVector < Underlying . TangentVector . VectorSpaceScalar > ) {
114
+ override func _move( along direction: AnyLayerTangentVector ) {
116
115
if let scalarDirection = direction. box. getOpaqueScalar ( ) {
117
116
underlying. move ( along: Underlying . TangentVector. zero. adding ( scalarDirection) )
118
117
} else {
@@ -125,7 +124,7 @@ where Underlying.TangentVector.VectorSpaceScalar: FloatingPoint & ElementaryFunc
125
124
}
126
125
127
126
// `EuclideanDifferentiable` requirements.
128
- public override var _differentiableVectorView : AnyLayerTangentVector < Underlying . TangentVector . VectorSpaceScalar > {
127
+ public override var _differentiableVectorView : AnyLayerTangentVector {
129
128
return AnyLayerTangentVector ( underlying. differentiableVectorView)
130
129
}
131
130
@@ -143,7 +142,7 @@ where Underlying.TangentVector.VectorSpaceScalar: FloatingPoint & ElementaryFunc
143
142
override func _vjpCallAsFunction( _ input: Underlying . Input ) -> (
144
143
value: Underlying . Output ,
145
144
pullback: ( Underlying . Output . TangentVector ) ->
146
- ( AnyLayerTangentVector < Underlying . TangentVector . VectorSpaceScalar > , Underlying . Input . TangentVector )
145
+ ( AnyLayerTangentVector , Underlying . Input . TangentVector )
147
146
) {
148
147
let basePullback = valueWithPullback (
149
148
at: ModelAndInput ( model: underlying, input: input) ,
@@ -155,7 +154,7 @@ where Underlying.TangentVector.VectorSpaceScalar: FloatingPoint & ElementaryFunc
155
154
pullback: { ( outTangent) in
156
155
let pairTangent = basePullback. pullback ( outTangent)
157
156
return (
158
- AnyLayerTangentVector < Underlying . TangentVector . VectorSpaceScalar > ( pairTangent. model) ,
157
+ AnyLayerTangentVector ( pairTangent. model) ,
159
158
pairTangent. input
160
159
)
161
160
}
@@ -164,12 +163,12 @@ where Underlying.TangentVector.VectorSpaceScalar: FloatingPoint & ElementaryFunc
164
163
165
164
// `CopyableToDevice` requirements.
166
165
override func _copyToDevice( to device: Device ) ->
167
- AnyLayerBox < Underlying . Input , Underlying . Output , Underlying . TangentVector . VectorSpaceScalar > {
166
+ AnyLayerBox < Underlying . Input , Underlying . Output > {
168
167
return ConcreteLayerBox ( Underlying ( copying: underlying, to: device) )
169
168
}
170
169
171
170
override func duplicate( ) ->
172
- AnyLayerBox < Underlying . Input , Underlying . Output , Underlying . TangentVector . VectorSpaceScalar > {
171
+ AnyLayerBox < Underlying . Input , Underlying . Output > {
173
172
return ConcreteLayerBox ( underlying)
174
173
}
175
174
}
@@ -189,11 +188,10 @@ where Underlying.TangentVector.VectorSpaceScalar: FloatingPoint & ElementaryFunc
189
188
/// Type Parameters:
190
189
/// - Input: the input type of the underlying layar
191
190
/// - Output: the output type of the underlying layer
192
- /// - Scalar: the scalar type of the underlying tangent vector
193
- public struct AnyLayer < Input: Differentiable , Output: Differentiable , Scalar: FloatingPoint & ElementaryFunctions > : CopyableToDevice {
194
- internal var box : AnyLayerBox < Input , Output , Scalar >
191
+ public struct AnyLayer < Input: Differentiable , Output: Differentiable > : CopyableToDevice {
192
+ internal var box : AnyLayerBox < Input , Output >
195
193
196
- internal init ( box: AnyLayerBox < Input , Output , Scalar > ) {
194
+ internal init ( box: AnyLayerBox < Input , Output > ) {
197
195
self . box = box
198
196
}
199
197
@@ -205,7 +203,7 @@ public struct AnyLayer<Input: Differentiable, Output: Differentiable, Scalar: Fl
205
203
/// Creates a type-erased derivative from the given layer.
206
204
@differentiable
207
205
public init < Underlying: Layer > ( _ layer: Underlying )
208
- where Underlying. Input == Input , Underlying. Output == Output , Underlying. TangentVector. VectorSpaceScalar == Scalar {
206
+ where Underlying. Input == Input , Underlying. Output == Output , Underlying. TangentVector. VectorSpaceScalar == Float {
209
207
self . box = ConcreteLayerBox < Underlying > ( layer)
210
208
}
211
209
@@ -217,25 +215,25 @@ public struct AnyLayer<Input: Differentiable, Output: Differentiable, Scalar: Fl
217
215
@derivative ( of: init)
218
216
internal static func _vjpInit< T: Layer > (
219
217
_ base: T
220
- ) -> ( value: AnyLayer , pullback: ( AnyLayerTangentVector < Scalar > ) -> T . TangentVector )
221
- where T. Input == Input , T. Output == Output , T. TangentVector. VectorSpaceScalar == Scalar
218
+ ) -> ( value: AnyLayer , pullback: ( AnyLayerTangentVector ) -> T . TangentVector )
219
+ where T. Input == Input , T. Output == Output , T. TangentVector. VectorSpaceScalar == Float
222
220
{
223
- return ( AnyLayer < Input , Output , Scalar > ( base) , { v in v. unboxed ( as: T . TangentVector. self) ! } )
221
+ return ( AnyLayer < Input , Output > ( base) , { v in v. unboxed ( as: T . TangentVector. self) ! } )
224
222
}
225
223
226
224
@inlinable
227
225
@derivative ( of: init)
228
226
internal static func _jvpInit< T: Layer > (
229
227
_ base: T
230
228
) -> (
231
- value: AnyLayer , differential: ( T . TangentVector ) -> AnyLayerTangentVector < Scalar >
232
- ) where T. Input == Input , T. Output == Output , T. TangentVector. VectorSpaceScalar == Scalar {
233
- return ( AnyLayer < Input , Output , Scalar > ( base) , { dbase in AnyLayerTangentVector < Scalar > ( dbase) } )
229
+ value: AnyLayer , differential: ( T . TangentVector ) -> AnyLayerTangentVector
230
+ ) where T. Input == Input , T. Output == Output , T. TangentVector. VectorSpaceScalar == Float {
231
+ return ( AnyLayer < Input , Output > ( base) , { dbase in AnyLayerTangentVector ( dbase) } )
234
232
}
235
233
}
236
234
237
235
extension AnyLayer : Differentiable {
238
- public typealias TangentVector = AnyLayerTangentVector < Scalar >
236
+ public typealias TangentVector = AnyLayerTangentVector
239
237
240
238
public mutating func move( along direction: TangentVector ) {
241
239
if !isKnownUniquelyReferenced( & box) { // preserve value semantics
@@ -260,7 +258,7 @@ extension AnyLayer: Layer {
260
258
261
259
@derivative ( of: _callAsFunction)
262
260
func _vjpCallAsFunction( _ input: Input ) ->
263
- ( value: Output , pullback: ( Output . TangentVector ) -> ( AnyLayerTangentVector < Scalar > , Input . TangentVector ) ) {
261
+ ( value: Output , pullback: ( Output . TangentVector ) -> ( AnyLayerTangentVector , Input . TangentVector ) ) {
264
262
return box. _vjpCallAsFunction ( input)
265
263
}
266
264
@@ -269,4 +267,3 @@ extension AnyLayer: Layer {
269
267
return _callAsFunction ( input)
270
268
}
271
269
}
272
- #endif
0 commit comments