Skip to content

Commit 428b33c

Browse files
committed
Fix @differentiable attribute formatting.
- Don't insert a space before `where` if it's the first thing in the attribute. - Apply formatting to comma-delimited parameter "tuples" following the `wrt:` clause. Fixes SR-12414.
1 parent 744ca11 commit 428b33c

File tree

3 files changed

+89
-3
lines changed

3 files changed

+89
-3
lines changed

Sources/SwiftFormatPrettyPrint/TokenStreamCreator.swift

+25-3
Original file line numberDiff line numberDiff line change
@@ -2020,16 +2020,27 @@ fileprivate final class TokenStreamCreator: SyntaxVisitor {
20202020

20212021
override func visit(_ node: DifferentiableAttributeArgumentsSyntax) -> SyntaxVisitorContinueKind {
20222022
// This node encapsulates the entire list of arguments in a `@differentiable(...)` attribute.
2023+
after(node.diffParamsComma, tokens: .break(.same))
2024+
2025+
var needsBreakBeforeWhereClause = false
2026+
20232027
if let vjp = node.maybeVJP {
2024-
before(vjp.firstToken, tokens: .break(.same), .open)
2028+
before(vjp.firstToken, tokens: .open)
20252029
after(vjp.lastToken, tokens: .close)
2030+
after(vjp.trailingComma, tokens: .break(.same))
2031+
needsBreakBeforeWhereClause = true
20262032
}
20272033
if let jvp = node.maybeJVP {
2028-
before(jvp.firstToken, tokens: .break(.same), .open)
2034+
before(jvp.firstToken, tokens: .open)
20292035
after(jvp.lastToken, tokens: .close)
2036+
after(jvp.trailingComma, tokens: .break(.same))
2037+
needsBreakBeforeWhereClause = true
20302038
}
20312039
if let whereClause = node.whereClause {
2032-
before(whereClause.firstToken, tokens: .break(.same), .open)
2040+
if needsBreakBeforeWhereClause {
2041+
before(whereClause.firstToken, tokens: .break(.same))
2042+
}
2043+
before(whereClause.firstToken, tokens: .open)
20332044
after(whereClause.lastToken, tokens: .close)
20342045
}
20352046
return .visitChildren
@@ -2044,6 +2055,17 @@ fileprivate final class TokenStreamCreator: SyntaxVisitor {
20442055
return .visitChildren
20452056
}
20462057

2058+
override func visit(_ node: DifferentiationParamsSyntax) -> SyntaxVisitorContinueKind {
2059+
after(node.leftParen, tokens: .break(.open, size: 0), .open)
2060+
before(node.rightParen, tokens: .break(.close, size: 0), .close)
2061+
return .visitChildren
2062+
}
2063+
2064+
override func visit(_ node: DifferentiationParamSyntax) -> SyntaxVisitorContinueKind {
2065+
after(node.trailingComma, tokens: .break(.same))
2066+
return .visitChildren
2067+
}
2068+
20472069
// `DerivativeRegistrationAttributeArguments` was added after the Swift 5.2 release was cut.
20482070
#if HAS_DERIVATIVE_REGISTRATION_ATTRIBUTE
20492071
override func visit(_ node: DerivativeRegistrationAttributeArgumentsSyntax)

Tests/SwiftFormatPrettyPrintTests/DifferentiationAttributeTests.swift

+62
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,68 @@ final class DifferentiationAttributeTests: PrettyPrintTestCase {
4343
assertPrettyPrintEqual(input: input, expected: expected, linelength: 43)
4444
}
4545

46+
func testDifferentiableWithOnlyWhereClause() {
47+
let input =
48+
"""
49+
@differentiable(where T: D)
50+
func foo<T>(_ x: T) -> T {}
51+
52+
@differentiable(where T: Differentiable)
53+
func foo<T>(_ x: T) -> T {}
54+
"""
55+
56+
let expected =
57+
"""
58+
@differentiable(where T: D)
59+
func foo<T>(_ x: T) -> T {}
60+
61+
@differentiable(
62+
where T: Differentiable
63+
)
64+
func foo<T>(_ x: T) -> T {}
65+
66+
"""
67+
68+
assertPrettyPrintEqual(input: input, expected: expected, linelength: 28)
69+
}
70+
71+
func testDifferentiableWithMultipleParameters() {
72+
let input =
73+
"""
74+
@differentiable(wrt: (x, y))
75+
func foo<T>(_ x: T) -> T {}
76+
77+
@differentiable(wrt: (self, x, y))
78+
func foo<T>(_ x: T) -> T {}
79+
80+
@differentiable(wrt: (theVariableNamedSelf, theVariableNamedX, theVariableNamedY))
81+
func foo<T>(_ x: T) -> T {}
82+
"""
83+
84+
let expected =
85+
"""
86+
@differentiable(wrt: (x, y))
87+
func foo<T>(_ x: T) -> T {}
88+
89+
@differentiable(
90+
wrt: (self, x, y)
91+
)
92+
func foo<T>(_ x: T) -> T {}
93+
94+
@differentiable(
95+
wrt: (
96+
theVariableNamedSelf,
97+
theVariableNamedX,
98+
theVariableNamedY
99+
)
100+
)
101+
func foo<T>(_ x: T) -> T {}
102+
103+
"""
104+
105+
assertPrettyPrintEqual(input: input, expected: expected, linelength: 28)
106+
}
107+
46108
func testDerivative() {
47109
#if HAS_DERIVATIVE_REGISTRATION_ATTRIBUTE
48110
let input =

Tests/SwiftFormatPrettyPrintTests/XCTestManifests.swift

+2
Original file line numberDiff line numberDiff line change
@@ -209,6 +209,8 @@ extension DifferentiationAttributeTests {
209209
static let __allTests__DifferentiationAttributeTests = [
210210
("testDerivative", testDerivative),
211211
("testDifferentiable", testDifferentiable),
212+
("testDifferentiableWithMultipleParameters", testDifferentiableWithMultipleParameters),
213+
("testDifferentiableWithOnlyWhereClause", testDifferentiableWithOnlyWhereClause),
212214
("testTranspose", testTranspose),
213215
]
214216
}

0 commit comments

Comments
 (0)