2
2
3
3
/// Good
4
4
5
- @derivative ( of: sin) // ok
6
- func jvpSin( x: @nondiff Float )
7
- -> ( value: Float , differential: ( Float ) -> ( Float ) ) {
8
- return ( x, { $0 } )
9
- }
10
-
11
5
@derivative ( of: sin, wrt: x) // ok
12
6
func vjpSin( x: Float ) -> ( value: Float , pullback: ( Float ) -> Float ) {
13
7
return ( x, { $0 } )
14
8
}
15
9
16
10
@derivative ( of: add, wrt: ( x, y) ) // ok
17
11
func vjpAdd( x: Float , y: Float )
18
- -> ( value: Float , pullback: ( Float ) -> ( Float , Float ) ) {
12
+ -> ( value: Float , pullback: ( Float ) -> ( Float , Float ) ) {
19
13
return ( x + y, { ( $0, $0) } )
20
14
}
21
15
22
- extension AdditiveArithmetic where Self : Differentiable {
16
+ extension AdditiveArithmetic where Self: Differentiable {
23
17
@derivative ( of: + ) // ok
24
- static func vjpPlus ( x: Self , y: Self ) -> ( value : Self ,
25
- pullback: ( Self . TangentVector ) -> ( Self . TangentVector , Self . TangentVector ) ) {
18
+ static func vjpAdd ( x: Self , y: Self )
19
+ -> ( value : Self , pullback: ( TangentVector ) -> ( TangentVector , TangentVector ) ) {
26
20
return ( x + y, { v in ( v, v) } )
27
21
}
28
22
}
29
23
30
- @derivative ( of: linear ) // ok
24
+ @derivative ( of: foo ) // ok
31
25
func dfoo( x: Float ) -> ( value: Float , differential: ( Float ) -> ( Float ) ) {
32
26
return ( x, { $0 } )
33
27
}
34
28
35
- // expected-error @+1 {{expected label 'wrt:' in '@derivative' attribute}}
36
- @derivative ( of: linear, linear) // ok
29
+ /// Bad
30
+
31
+ // expected-error @+3 {{expected an original function name}}
32
+ // expected-error @+2 {{expected ')' in 'derivative' attribute}}
33
+ // expected-error @+1 {{expected declaration}}
34
+ @derivative ( of: 3 )
37
35
func dfoo( x: Float ) -> ( value: Float , differential: ( Float ) -> ( Float ) ) {
38
36
return ( x, { $0 } )
39
37
}
40
38
41
39
// expected-error @+1 {{expected label 'wrt:' in '@derivative' attribute}}
42
- @derivative ( of: foo , linear , wrt : x ) // ok
40
+ @derivative ( of: wrt , foo )
43
41
func dfoo( x: Float ) -> ( value: Float , differential: ( Float ) -> ( Float ) ) {
44
42
return ( x, { $0 } )
45
43
}
46
44
47
- /// Bad
48
-
49
- // expected-error @+3 {{expected an original function name}}
50
- // expected-error @+2 {{expected ')' in 'derivative' attribute}}
51
- // expected-error @+1 {{expected declaration}}
52
- @derivative ( of: 3 )
45
+ // expected-error @+1 {{expected a colon ':' after 'wrt'}}
46
+ @derivative ( of: foo, wrt)
53
47
func dfoo( x: Float ) -> ( value: Float , differential: ( Float ) -> ( Float ) ) {
54
48
return ( x, { $0 } )
55
49
}
56
50
57
51
// expected-error @+1 {{expected label 'wrt:' in '@derivative' attribute}}
58
- @derivative ( of: linear , foo )
52
+ @derivative ( of: foo , blah , wrt : x )
59
53
func dfoo( x: Float ) -> ( value: Float , differential: ( Float ) -> ( Float ) ) {
60
54
return ( x, { $0 } )
61
55
}
62
56
63
57
// expected-error @+2 {{expected ')' in 'derivative' attribute}}
64
58
// expected-error @+1 {{expected declaration}}
65
- @derivative ( of: foo, wrt: x, linear )
59
+ @derivative ( of: foo, wrt: x, blah )
66
60
func dfoo( x: Float ) -> ( value: Float , differential: ( Float ) -> ( Float ) ) {
67
61
return ( x, { $0 } )
68
62
}
@@ -81,13 +75,13 @@ func dfoo(x: Float) -> (value: Float, differential: (Float) -> (Float)) {
81
75
}
82
76
83
77
// expected-error @+1 {{expected label 'wrt:' in '@derivative' attribute}}
84
- @derivative ( of: linear , foo, )
78
+ @derivative ( of: foo , foo, )
85
79
func dfoo( x: Float ) -> ( value: Float , differential: ( Float ) -> ( Float ) ) {
86
80
return ( x, { $0 } )
87
81
}
88
82
89
83
// expected-error @+1 {{unexpected ',' separator}}
90
- @derivative ( of: linear , )
84
+ @derivative ( of: foo , )
91
85
func dfoo( x: Float ) -> ( value: Float , differential: ( Float ) -> ( Float ) ) {
92
86
return ( x, { $0 } )
93
87
}
0 commit comments