Skip to content

Commit 0da09d0

Browse files
committed
Minor updates to @differentiable.
swiftlang/swift#30001 removed the logic from the parser to handle the `jvp:` and `vjp:` arguments of the attribute (but left the syntax definitions in place for the time being). Since this causes an attribute with those arguments to fail to parse, I've removed them from the tests so that those tests continue to pass under the new behavior. (The arguments were always optional, so they pass under the old behavior as well.) This also caught and fixed a bug where an attribute with only a `wrt:` and a `where` clause wasn't getting formatted correctly.
1 parent 513a2c7 commit 0da09d0

File tree

2 files changed

+19
-18
lines changed

2 files changed

+19
-18
lines changed

Sources/SwiftFormatPrettyPrint/TokenStreamCreator.swift

+13-2
Original file line numberDiff line numberDiff line change
@@ -2035,10 +2035,18 @@ fileprivate final class TokenStreamCreator: SyntaxVisitor {
20352035

20362036
override func visit(_ node: DifferentiableAttributeArgumentsSyntax) -> SyntaxVisitorContinueKind {
20372037
// This node encapsulates the entire list of arguments in a `@differentiable(...)` attribute.
2038-
after(node.diffParamsComma, tokens: .break(.same))
2039-
20402038
var needsBreakBeforeWhereClause = false
20412039

2040+
if let diffParamsComma = node.diffParamsComma {
2041+
after(diffParamsComma, tokens: .break(.same))
2042+
} else if node.diffParams != nil {
2043+
// If there were diff params but no comma following them, then we have "wrt: foo where ..."
2044+
// and we need a break before the where clause.
2045+
needsBreakBeforeWhereClause = true
2046+
}
2047+
2048+
// TODO: These properties will likely go away in a future version since the parser no longer
2049+
// reads the `vjp:` and `jvp:` arguments to `@differentiable`.
20422050
if let vjp = node.maybeVJP {
20432051
before(vjp.firstToken, tokens: .open)
20442052
after(vjp.lastToken, tokens: .close)
@@ -2051,6 +2059,7 @@ fileprivate final class TokenStreamCreator: SyntaxVisitor {
20512059
after(jvp.trailingComma, tokens: .break(.same))
20522060
needsBreakBeforeWhereClause = true
20532061
}
2062+
20542063
if let whereClause = node.whereClause {
20552064
if needsBreakBeforeWhereClause {
20562065
before(whereClause.firstToken, tokens: .break(.same))
@@ -2066,6 +2075,8 @@ fileprivate final class TokenStreamCreator: SyntaxVisitor {
20662075
{
20672076
// This node encapsulates the `vjp:` or `jvp:` label and decl name in a `@differentiable`
20682077
// attribute.
2078+
// TODO: This node will likely go away in a future version since the parser no longer reads the
2079+
// `vjp:` and `jvp:` arguments to `@differentiable`.
20692080
after(node.colon, tokens: .break(.continue, newlines: .elective(ignoresDiscretionary: true)))
20702081
return .visitChildren
20712082
}

Tests/SwiftFormatPrettyPrintTests/DifferentiationAttributeTests.swift

+6-16
Original file line numberDiff line numberDiff line change
@@ -2,38 +2,28 @@ final class DifferentiationAttributeTests: PrettyPrintTestCase {
22
func testDifferentiable() {
33
let input =
44
"""
5-
@differentiable(wrt: x, vjp: d where T: D)
5+
@differentiable(wrt: x where T: D)
66
func foo<T>(_ x: T) -> T {}
77
8-
@differentiable(wrt: x, vjp: deriv where T: D)
8+
@differentiable(wrt: x where T: Differentiable)
99
func foo<T>(_ x: T) -> T {}
1010
11-
@differentiable(wrt: x, vjp: derivativeFoo where T: Differentiable)
12-
func foo<T>(_ x: T) -> T {}
13-
14-
@differentiable(wrt: theVariableNamedX, vjp: derivativeFoo where T: Differentiable)
11+
@differentiable(wrt: theVariableNamedX where T: Differentiable)
1512
func foo<T>(_ theVariableNamedX: T) -> T {}
1613
"""
1714

1815
let expected =
1916
"""
20-
@differentiable(wrt: x, vjp: d where T: D)
17+
@differentiable(wrt: x where T: D)
2118
func foo<T>(_ x: T) -> T {}
2219
2320
@differentiable(
24-
wrt: x, vjp: deriv where T: D
25-
)
26-
func foo<T>(_ x: T) -> T {}
27-
28-
@differentiable(
29-
wrt: x, vjp: derivativeFoo
30-
where T: Differentiable
21+
wrt: x where T: Differentiable
3122
)
3223
func foo<T>(_ x: T) -> T {}
3324
3425
@differentiable(
35-
wrt: theVariableNamedX,
36-
vjp: derivativeFoo
26+
wrt: theVariableNamedX
3727
where T: Differentiable
3828
)
3929
func foo<T>(_ theVariableNamedX: T) -> T {}

0 commit comments

Comments
 (0)