Skip to content

Commit 58a6e7a

Browse files
authored
[AutoDiff] NFC: Gardening. (#28673)
Minor style fixes and gardening. Change code to be closer to master after #28321.
1 parent f3602bb commit 58a6e7a

File tree

8 files changed

+118
-83
lines changed

8 files changed

+118
-83
lines changed

include/swift/AST/Attr.h

+4-8
Original file line numberDiff line numberDiff line change
@@ -1940,10 +1940,6 @@ class DifferentiableAttr final
19401940
FuncDecl *getVJPFunction() const { return VJPFunction; }
19411941
void setVJPFunction(FuncDecl *decl);
19421942

1943-
bool parametersMatch(const DifferentiableAttr &other) const {
1944-
return getParameterIndices() == other.getParameterIndices();
1945-
}
1946-
19471943
/// Get the derivative generic environment for the given `@differentiable`
19481944
/// attribute and original function.
19491945
GenericEnvironment *
@@ -2031,8 +2027,8 @@ class DerivativeAttr final
20312027
IndexSubset *getParameterIndices() const {
20322028
return ParameterIndices;
20332029
}
2034-
void setParameterIndices(IndexSubset *pi) {
2035-
ParameterIndices = pi;
2030+
void setParameterIndices(IndexSubset *parameterIndices) {
2031+
ParameterIndices = parameterIndices;
20362032
}
20372033

20382034
static bool classof(const DeclAttribute *DA) {
@@ -2111,8 +2107,8 @@ class TransposeAttr final
21112107
IndexSubset *getParameterIndices() const {
21122108
return ParameterIndices;
21132109
}
2114-
void setParameterIndices(IndexSubset *pi) {
2115-
ParameterIndices = pi;
2110+
void setParameterIndices(IndexSubset *parameterIndices) {
2111+
ParameterIndices = parameterIndices;
21162112
}
21172113

21182114
static bool classof(const DeclAttribute *DA) {

include/swift/AST/AutoDiff.h

+40-40
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,46 @@
2626

2727
namespace swift {
2828

29+
enum class DifferentiabilityKind : uint8_t {
30+
NonDifferentiable = 0,
31+
Normal = 1,
32+
Linear = 2
33+
};
34+
35+
/// The kind of an linear map.
36+
struct AutoDiffLinearMapKind {
37+
enum innerty : uint8_t {
38+
// The differential function.
39+
Differential = 0,
40+
// The pullback function.
41+
Pullback = 1
42+
} rawValue;
43+
44+
AutoDiffLinearMapKind() = default;
45+
AutoDiffLinearMapKind(innerty rawValue) : rawValue(rawValue) {}
46+
operator innerty() const { return rawValue; }
47+
};
48+
49+
/// The kind of a derivative function.
50+
struct AutoDiffDerivativeFunctionKind {
51+
enum innerty : uint8_t {
52+
// The Jacobian-vector products function.
53+
JVP = 0,
54+
// The vector-Jacobian products function.
55+
VJP = 1
56+
} rawValue;
57+
58+
AutoDiffDerivativeFunctionKind() = default;
59+
AutoDiffDerivativeFunctionKind(innerty rawValue) : rawValue(rawValue) {}
60+
AutoDiffDerivativeFunctionKind(AutoDiffLinearMapKind linMapKind)
61+
: rawValue(static_cast<innerty>(linMapKind.rawValue)) {}
62+
explicit AutoDiffDerivativeFunctionKind(StringRef string);
63+
operator innerty() const { return rawValue; }
64+
AutoDiffLinearMapKind getLinearMapKind() {
65+
return (AutoDiffLinearMapKind::innerty)rawValue;
66+
}
67+
};
68+
2969
class ParsedAutoDiffParameter {
3070
public:
3171
enum class Kind { Named, Ordered, Self };
@@ -89,12 +129,6 @@ class ParsedAutoDiffParameter {
89129
}
90130
};
91131

92-
enum class DifferentiabilityKind : uint8_t {
93-
NonDifferentiable = 0,
94-
Normal = 1,
95-
Linear = 2
96-
};
97-
98132
} // end namespace swift
99133

100134
// SWIFT_ENABLE_TENSORFLOW
@@ -120,40 +154,6 @@ class SILFunctionType;
120154
typedef CanTypeWrapper<SILFunctionType> CanSILFunctionType;
121155
enum class SILLinkage : uint8_t;
122156

123-
/// The kind of an linear map.
124-
struct AutoDiffLinearMapKind {
125-
enum innerty : uint8_t {
126-
// The differential function.
127-
Differential = 0,
128-
// The pullback function.
129-
Pullback = 1
130-
} rawValue;
131-
132-
AutoDiffLinearMapKind() = default;
133-
AutoDiffLinearMapKind(innerty rawValue) : rawValue(rawValue) {}
134-
operator innerty() const { return rawValue; }
135-
};
136-
137-
/// The kind of a derivative function.
138-
struct AutoDiffDerivativeFunctionKind {
139-
enum innerty : uint8_t {
140-
// The Jacobian-vector products function.
141-
JVP = 0,
142-
// The vector-Jacobian products function.
143-
VJP = 1
144-
} rawValue;
145-
146-
AutoDiffDerivativeFunctionKind() = default;
147-
AutoDiffDerivativeFunctionKind(innerty rawValue) : rawValue(rawValue) {}
148-
AutoDiffDerivativeFunctionKind(AutoDiffLinearMapKind linMapKind)
149-
: rawValue(static_cast<innerty>(linMapKind.rawValue)) {}
150-
explicit AutoDiffDerivativeFunctionKind(StringRef string);
151-
operator innerty() const { return rawValue; }
152-
AutoDiffLinearMapKind getLinearMapKind() {
153-
return (AutoDiffLinearMapKind::innerty)rawValue;
154-
}
155-
};
156-
157157
/// The kind of a differentiability witness function.
158158
struct DifferentiabilityWitnessFunctionKind {
159159
enum innerty : uint8_t {

include/swift/AST/DiagnosticsParse.def

+5-4
Original file line numberDiff line numberDiff line change
@@ -1364,6 +1364,11 @@ ERROR(attr_expected_comma,none,
13641364
ERROR(attr_expected_string_literal,none,
13651365
"expected string literal in '%0' attribute", (StringRef))
13661366

1367+
ERROR(attr_missing_label,PointsToFirstBadToken,
1368+
"missing label '%0:' in '@%1' attribute", (StringRef, StringRef))
1369+
ERROR(attr_expected_label,none,
1370+
"expected label '%0:' in '@%1' attribute", (StringRef, StringRef))
1371+
13671372
ERROR(alignment_must_be_positive_integer,none,
13681373
"alignment value must be a positive integer literal", ())
13691374

@@ -1579,10 +1584,6 @@ ERROR(diff_params_clause_expected_parameter,PointsToFirstBadToken,
15791584
// derivative
15801585
ERROR(attr_derivative_expected_original_name,PointsToFirstBadToken,
15811586
"expected an original function name", ())
1582-
ERROR(attr_missing_label,PointsToFirstBadToken,
1583-
"missing label '%0:' in '@%1' attribute", (StringRef, StringRef))
1584-
ERROR(attr_expected_label,none,
1585-
"expected label '%0:' in '@%1' attribute", (StringRef, StringRef))
15861587
WARNING(attr_differentiating_deprecated,PointsToFirstBadToken,
15871588
"'@differentiating' attribute is deprecated; use '@derivative(of:)' "
15881589
"instead", ())

lib/Parse/ParseDecl.cpp

+46-2
Original file line numberDiff line numberDiff line change
@@ -804,6 +804,18 @@ Parser::parseImplementsAttribute(SourceLoc AtLoc, SourceLoc Loc) {
804804
ProtocolType.get(), MemberName, MemberNameLoc));
805805
}
806806

807+
/// Parse a `@differentiable` attribute, returning true on error.
808+
///
809+
/// \verbatim
810+
/// differentiable-attribute-arguments:
811+
/// '(' (differentiation-params-clause ',')?
812+
/// (differentiable-attr-func-specifier ',')?
813+
/// differentiable-attr-func-specifier?
814+
/// where-clause?
815+
/// ')'
816+
/// differentiable-attr-func-specifier:
817+
/// ('jvp' | 'vjp') ':' decl-name
818+
/// \endverbatim
807819
ParserResult<DifferentiableAttr>
808820
Parser::parseDifferentiableAttribute(SourceLoc atLoc, SourceLoc loc) {
809821
StringRef AttrName = "differentiable";
@@ -852,6 +864,16 @@ static bool errorAndSkipUntilConsumeRightParen(Parser &P, StringRef attrName,
852864
return true;
853865
};
854866

867+
/// Parse a differentiation parameters 'wrt:' clause, returning true on error.
868+
///
869+
/// \verbatim
870+
/// differentiation-params-clause:
871+
/// 'wrt' ':' (differentiation-param | differentiation-params)
872+
/// differentiation-params:
873+
/// '(' differentiation-param (',' differentiation-param)* ')'
874+
/// differentiation-param:
875+
/// 'self' | identifier
876+
/// \endverbatim
855877
bool Parser::parseDifferentiationParametersClause(
856878
SmallVectorImpl<ParsedAutoDiffParameter> &params, StringRef attrName) {
857879
SyntaxParsingContext DiffParamsClauseContext(
@@ -929,6 +951,16 @@ bool Parser::parseDifferentiationParametersClause(
929951
}
930952

931953
// SWIFT_ENABLE_TENSORFLOW
954+
/// Parse a transposed parameters 'wrt:' clause, returning true on error.
955+
///
956+
/// \verbatim
957+
/// transposed-params-clause:
958+
/// 'wrt' ':' (transposed-param | transposed-params)
959+
/// transposed-params:
960+
/// '(' transposed-param (',' transposed-param)* ')'
961+
/// transposed-param:
962+
/// 'self' | [0-9]+
963+
/// \endverbatim
932964
bool Parser::parseTransposedParametersClause(
933965
SmallVectorImpl<ParsedAutoDiffParameter> &params, StringRef attrName) {
934966
SyntaxParsingContext TransposeParamsClauseContext(
@@ -1130,7 +1162,13 @@ bool Parser::parseDifferentiableAttributeArguments(
11301162
return false;
11311163
}
11321164

1133-
/// SWIFT_ENABLE_TENSORFLOW
1165+
// SWIFT_ENABLE_TENSORFLOW
1166+
/// Parse a `@derivative(of:)` attribute, returning true on error.
1167+
///
1168+
/// \verbatim
1169+
/// derivative-attribute-arguments:
1170+
/// '(' 'of' ':' decl-name (',' differentiation-params-clause)? ')'
1171+
/// \endverbatim
11341172
ParserResult<DerivativeAttr> Parser::parseDerivativeAttribute(SourceLoc atLoc,
11351173
SourceLoc loc) {
11361174
StringRef AttrName = "derivative";
@@ -1196,7 +1234,7 @@ ParserResult<DerivativeAttr> Parser::parseDerivativeAttribute(SourceLoc atLoc,
11961234
SourceRange(loc, rParenLoc), original, params));
11971235
}
11981236

1199-
/// SWIFT_ENABLE_TENSORFLOW
1237+
// SWIFT_ENABLE_TENSORFLOW
12001238
ParserResult<DerivativeAttr>
12011239
Parser::parseDifferentiatingAttribute(SourceLoc atLoc, SourceLoc loc) {
12021240
StringRef AttrName = "differentiating";
@@ -1315,6 +1353,12 @@ bool parseQualifiedDeclName(Parser &P, Diag<> nameParseError,
13151353
// SWIFT_ENABLE_TENSORFLOW END
13161354

13171355
// SWIFT_ENABLE_TENSORFLOW
1356+
/// Parse a `@transpose(of:)` attribute, returning true on error.
1357+
///
1358+
/// \verbatim
1359+
/// transpose-attribute-arguments:
1360+
/// '(' 'of' ':' decl-name (',' transposed-params-clause)? ')'
1361+
/// \endverbatim
13181362
ParserResult<TransposeAttr> Parser::parseTransposeAttribute(SourceLoc atLoc,
13191363
SourceLoc loc) {
13201364
StringRef AttrName = "transpose";

lib/Serialization/ModuleFormat.h

+1-1
Original file line numberDiff line numberDiff line change
@@ -237,7 +237,7 @@ using DifferentiabilityKindField = BCFixed<2>;
237237
// module version.
238238
enum class AutoDiffDerivativeFunctionKind : uint8_t {
239239
JVP = 0,
240-
VJP = 1
240+
VJP
241241
};
242242
using AutoDiffDerivativeFunctionKindField = BCFixed<1>;
243243
// SWIFT_ENABLE_TENSORFLOW END

lib/Serialization/Serialization.cpp

+2-2
Original file line numberDiff line numberDiff line change
@@ -2029,8 +2029,8 @@ static uint8_t getRawStableVarDeclIntroducer(swift::VarDecl::Introducer intr) {
20292029
}
20302030

20312031
// SWIFT_ENABLE_TENSORFLOW
2032-
/// Translate from the AST differentiability kind enum to the Serialization enum
2033-
/// values, which are guaranteed to be stable.
2032+
/// Translate from the AST derivative function kind enum to the Serialization
2033+
/// enum values, which are guaranteed to be stable.
20342034
static uint8_t getRawStableAutoDiffDerivativeFunctionKind(
20352035
swift::AutoDiffDerivativeFunctionKind kind) {
20362036
switch (kind) {

lib/TBDGen/TBDGen.cpp

+2-2
Original file line numberDiff line numberDiff line change
@@ -183,11 +183,11 @@ void TBDGenVisitor::addAutoDiffLinearMapFunction(AbstractFunctionDecl *original,
183183
AutoDiffLinearMapKind kind) {
184184
auto declRef = SILDeclRef(original);
185185

186-
// Linear maps are only public when the original function is serialized.
186+
// Linear maps are public only when the original function is serialized.
187187
if (!declRef.isSerialized())
188188
return;
189189

190-
// Differentials are only emitted when forward mode is turned on.
190+
// Linear maps are emitted only when forward mode is enabled.
191191
if (kind == AutoDiffLinearMapKind::Differential &&
192192
!original->getASTContext()
193193
.LangOpts.EnableExperimentalForwardModeDifferentiation)

test/AutoDiff/derivative_attr_parse.swift

+18-24
Original file line numberDiff line numberDiff line change
@@ -2,67 +2,61 @@
22

33
/// Good
44

5-
@derivative(of: sin) // ok
6-
func jvpSin(x: @nondiff Float)
7-
-> (value: Float, differential: (Float)-> (Float)) {
8-
return (x, { $0 })
9-
}
10-
115
@derivative(of: sin, wrt: x) // ok
126
func vjpSin(x: Float) -> (value: Float, pullback: (Float) -> Float) {
137
return (x, { $0 })
148
}
159

1610
@derivative(of: add, wrt: (x, y)) // ok
1711
func vjpAdd(x: Float, y: Float)
18-
-> (value: Float, pullback: (Float) -> (Float, Float)) {
12+
-> (value: Float, pullback: (Float) -> (Float, Float)) {
1913
return (x + y, { ($0, $0) })
2014
}
2115

22-
extension AdditiveArithmetic where Self : Differentiable {
16+
extension AdditiveArithmetic where Self: Differentiable {
2317
@derivative(of: +) // ok
24-
static func vjpPlus(x: Self, y: Self) -> (value: Self,
25-
pullback: (Self.TangentVector) -> (Self.TangentVector, Self.TangentVector)) {
18+
static func vjpAdd(x: Self, y: Self)
19+
-> (value: Self, pullback: (TangentVector) -> (TangentVector, TangentVector)) {
2620
return (x + y, { v in (v, v) })
2721
}
2822
}
2923

30-
@derivative(of: linear) // ok
24+
@derivative(of: foo) // ok
3125
func dfoo(x: Float) -> (value: Float, differential: (Float) -> (Float)) {
3226
return (x, { $0 })
3327
}
3428

35-
// expected-error @+1 {{expected label 'wrt:' in '@derivative' attribute}}
36-
@derivative(of: linear, linear) // ok
29+
/// Bad
30+
31+
// expected-error @+3 {{expected an original function name}}
32+
// expected-error @+2 {{expected ')' in 'derivative' attribute}}
33+
// expected-error @+1 {{expected declaration}}
34+
@derivative(of: 3)
3735
func dfoo(x: Float) -> (value: Float, differential: (Float) -> (Float)) {
3836
return (x, { $0 })
3937
}
4038

4139
// expected-error @+1 {{expected label 'wrt:' in '@derivative' attribute}}
42-
@derivative(of: foo, linear, wrt: x) // ok
40+
@derivative(of: wrt, foo)
4341
func dfoo(x: Float) -> (value: Float, differential: (Float) -> (Float)) {
4442
return (x, { $0 })
4543
}
4644

47-
/// Bad
48-
49-
// expected-error @+3 {{expected an original function name}}
50-
// expected-error @+2 {{expected ')' in 'derivative' attribute}}
51-
// expected-error @+1 {{expected declaration}}
52-
@derivative(of: 3)
45+
// expected-error @+1 {{expected a colon ':' after 'wrt'}}
46+
@derivative(of: foo, wrt)
5347
func dfoo(x: Float) -> (value: Float, differential: (Float) -> (Float)) {
5448
return (x, { $0 })
5549
}
5650

5751
// expected-error @+1 {{expected label 'wrt:' in '@derivative' attribute}}
58-
@derivative(of: linear, foo)
52+
@derivative(of: foo, blah, wrt: x)
5953
func dfoo(x: Float) -> (value: Float, differential: (Float) -> (Float)) {
6054
return (x, { $0 })
6155
}
6256

6357
// expected-error @+2 {{expected ')' in 'derivative' attribute}}
6458
// expected-error @+1 {{expected declaration}}
65-
@derivative(of: foo, wrt: x, linear)
59+
@derivative(of: foo, wrt: x, blah)
6660
func dfoo(x: Float) -> (value: Float, differential: (Float) -> (Float)) {
6761
return (x, { $0 })
6862
}
@@ -81,13 +75,13 @@ func dfoo(x: Float) -> (value: Float, differential: (Float) -> (Float)) {
8175
}
8276

8377
// expected-error @+1 {{expected label 'wrt:' in '@derivative' attribute}}
84-
@derivative(of: linear, foo,)
78+
@derivative(of: foo, foo,)
8579
func dfoo(x: Float) -> (value: Float, differential: (Float) -> (Float)) {
8680
return (x, { $0 })
8781
}
8882

8983
// expected-error @+1 {{unexpected ',' separator}}
90-
@derivative(of: linear,)
84+
@derivative(of: foo,)
9185
func dfoo(x: Float) -> (value: Float, differential: (Float) -> (Float)) {
9286
return (x, { $0 })
9387
}

0 commit comments

Comments
 (0)