Skip to content

Commit 0585eb0

Browse files
authored
[AutoDiff upstream] Add @derivative(of:) attribute. (#28321)
The `@derivative(of:)` attribute registers a function as a derivative of another function. This patch adds the `@derivative(of:)` attribute definition, syntax, parsing, and printing. Resolves TF-826. Todos: - Type-checking (TF-829). - Serialization (TF-837).
1 parent 678a707 commit 0585eb0

19 files changed

+588
-50
lines changed

include/swift/AST/Attr.def

+5
Original file line numberDiff line numberDiff line change
@@ -529,6 +529,11 @@ DECL_ATTR(_originallyDefinedIn, OriginallyDefinedIn,
529529
ABIBreakingToAdd | ABIBreakingToRemove | APIStableToAdd | APIStableToRemove,
530530
96)
531531

532+
DECL_ATTR(derivative, Derivative,
533+
OnFunc | LongAttribute | AllowMultipleAttributes |
534+
ABIStableToAdd | ABIBreakingToRemove | APIStableToAdd | APIBreakingToRemove,
535+
97)
536+
532537
#undef TYPE_ATTR
533538
#undef DECL_ATTR_ALIAS
534539
#undef CONTEXTUAL_DECL_ATTR_ALIAS

include/swift/AST/Attr.h

+78
Original file line numberDiff line numberDiff line change
@@ -1747,6 +1747,84 @@ class DifferentiableAttr final
17471747
}
17481748
};
17491749

1750+
/// Attribute that registers a function as a derivative of another function.
1751+
///
1752+
/// Examples:
1753+
/// @derivative(of: sin(_:))
1754+
/// @derivative(of: +, wrt: (lhs, rhs))
1755+
class DerivativeAttr final
1756+
: public DeclAttribute,
1757+
private llvm::TrailingObjects<DerivativeAttr, ParsedAutoDiffParameter> {
1758+
friend TrailingObjects;
1759+
1760+
/// The original function name.
1761+
DeclNameWithLoc OriginalFunctionName;
1762+
/// The original function declaration, resolved by the type checker.
1763+
AbstractFunctionDecl *OriginalFunction = nullptr;
1764+
/// The number of parsed parameters specified in 'wrt:'.
1765+
unsigned NumParsedParameters = 0;
1766+
/// The differentiation parameters' indices, resolved by the type checker.
1767+
IndexSubset *ParameterIndices = nullptr;
1768+
/// The derivative function kind (JVP or VJP), resolved by the type checker.
1769+
Optional<AutoDiffDerivativeFunctionKind> Kind = None;
1770+
1771+
explicit DerivativeAttr(bool implicit, SourceLoc atLoc, SourceRange baseRange,
1772+
DeclNameWithLoc original,
1773+
ArrayRef<ParsedAutoDiffParameter> params);
1774+
1775+
explicit DerivativeAttr(bool implicit, SourceLoc atLoc, SourceRange baseRange,
1776+
DeclNameWithLoc original, IndexSubset *indices);
1777+
1778+
public:
1779+
static DerivativeAttr *create(ASTContext &context, bool implicit,
1780+
SourceLoc atLoc, SourceRange baseRange,
1781+
DeclNameWithLoc original,
1782+
ArrayRef<ParsedAutoDiffParameter> params);
1783+
1784+
static DerivativeAttr *create(ASTContext &context, bool implicit,
1785+
SourceLoc atLoc, SourceRange baseRange,
1786+
DeclNameWithLoc original, IndexSubset *indices);
1787+
1788+
DeclNameWithLoc getOriginalFunctionName() const {
1789+
return OriginalFunctionName;
1790+
}
1791+
AbstractFunctionDecl *getOriginalFunction() const {
1792+
return OriginalFunction;
1793+
}
1794+
void setOriginalFunction(AbstractFunctionDecl *decl) {
1795+
OriginalFunction = decl;
1796+
}
1797+
1798+
AutoDiffDerivativeFunctionKind getDerivativeKind() const {
1799+
assert(Kind && "Derivative function kind has not yet been resolved");
1800+
return *Kind;
1801+
}
1802+
void setDerivativeKind(AutoDiffDerivativeFunctionKind kind) { Kind = kind; }
1803+
1804+
/// The parsed differentiation parameters, i.e. the list of parameters
1805+
/// specified in 'wrt:'.
1806+
ArrayRef<ParsedAutoDiffParameter> getParsedParameters() const {
1807+
return {getTrailingObjects<ParsedAutoDiffParameter>(), NumParsedParameters};
1808+
}
1809+
MutableArrayRef<ParsedAutoDiffParameter> getParsedParameters() {
1810+
return {getTrailingObjects<ParsedAutoDiffParameter>(), NumParsedParameters};
1811+
}
1812+
size_t numTrailingObjects(OverloadToken<ParsedAutoDiffParameter>) const {
1813+
return NumParsedParameters;
1814+
}
1815+
1816+
IndexSubset *getParameterIndices() const {
1817+
return ParameterIndices;
1818+
}
1819+
void setParameterIndices(IndexSubset *parameterIndices) {
1820+
ParameterIndices = parameterIndices;
1821+
}
1822+
1823+
static bool classof(const DeclAttribute *DA) {
1824+
return DA->getKind() == DAK_Derivative;
1825+
}
1826+
};
1827+
17501828
/// Attributes that may be applied to declarations.
17511829
class DeclAttributes {
17521830
/// Linked list of declaration attributes.

include/swift/AST/AutoDiff.h

+41-6
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,47 @@
2626

2727
namespace swift {
2828

29+
/// A function type differentiability kind.
30+
enum class DifferentiabilityKind : uint8_t {
31+
NonDifferentiable = 0,
32+
Normal = 1,
33+
Linear = 2
34+
};
35+
36+
/// The kind of an linear map.
37+
struct AutoDiffLinearMapKind {
38+
enum innerty : uint8_t {
39+
// The differential function.
40+
Differential = 0,
41+
// The pullback function.
42+
Pullback = 1
43+
} rawValue;
44+
45+
AutoDiffLinearMapKind() = default;
46+
AutoDiffLinearMapKind(innerty rawValue) : rawValue(rawValue) {}
47+
operator innerty() const { return rawValue; }
48+
};
49+
50+
/// The kind of a derivative function.
51+
struct AutoDiffDerivativeFunctionKind {
52+
enum innerty : uint8_t {
53+
// The Jacobian-vector products function.
54+
JVP = 0,
55+
// The vector-Jacobian products function.
56+
VJP = 1
57+
} rawValue;
58+
59+
AutoDiffDerivativeFunctionKind() = default;
60+
AutoDiffDerivativeFunctionKind(innerty rawValue) : rawValue(rawValue) {}
61+
AutoDiffDerivativeFunctionKind(AutoDiffLinearMapKind linMapKind)
62+
: rawValue(static_cast<innerty>(linMapKind.rawValue)) {}
63+
explicit AutoDiffDerivativeFunctionKind(StringRef string);
64+
operator innerty() const { return rawValue; }
65+
AutoDiffLinearMapKind getLinearMapKind() {
66+
return (AutoDiffLinearMapKind::innerty)rawValue;
67+
}
68+
};
69+
2970
class ParsedAutoDiffParameter {
3071
public:
3172
enum class Kind { Named, Ordered, Self };
@@ -89,12 +130,6 @@ class ParsedAutoDiffParameter {
89130
}
90131
};
91132

92-
enum class DifferentiabilityKind : uint8_t {
93-
NonDifferentiable = 0,
94-
Normal = 1,
95-
Linear = 2
96-
};
97-
98133
} // end namespace swift
99134

100135
#endif // SWIFT_AST_AUTODIFF_H

include/swift/AST/DiagnosticsParse.def

+9
Original file line numberDiff line numberDiff line change
@@ -1329,6 +1329,11 @@ ERROR(attr_expected_comma,none,
13291329
ERROR(attr_expected_string_literal,none,
13301330
"expected string literal in '%0' attribute", (StringRef))
13311331

1332+
ERROR(attr_missing_label,PointsToFirstBadToken,
1333+
"missing label '%0:' in '@%1' attribute", (StringRef, StringRef))
1334+
ERROR(attr_expected_label,none,
1335+
"expected label '%0:' in '@%1' attribute", (StringRef, StringRef))
1336+
13321337
ERROR(alignment_must_be_positive_integer,none,
13331338
"alignment value must be a positive integer literal", ())
13341339

@@ -1550,6 +1555,10 @@ ERROR(diff_params_clause_expected_parameter,PointsToFirstBadToken,
15501555
"expected a parameter, which can be a function parameter name, "
15511556
"parameter index, or 'self'", ())
15521557

1558+
// derivative
1559+
ERROR(attr_derivative_expected_original_name,PointsToFirstBadToken,
1560+
"expected an original function name", ())
1561+
15531562
//------------------------------------------------------------------------------
15541563
// MARK: Generics parsing diagnostics
15551564
//------------------------------------------------------------------------------

include/swift/Parse/Parser.h

+4
Original file line numberDiff line numberDiff line change
@@ -997,6 +997,10 @@ class Parser {
997997
bool parseDifferentiationParametersClause(
998998
SmallVectorImpl<ParsedAutoDiffParameter> &params, StringRef attrName);
999999

1000+
/// Parse the @derivative attribute.
1001+
ParserResult<DerivativeAttr> parseDerivativeAttribute(SourceLoc AtLoc,
1002+
SourceLoc Loc);
1003+
10001004
/// Parse a specific attribute.
10011005
ParserStatus parseDeclAttribute(DeclAttributes &Attributes, SourceLoc AtLoc);
10021006

lib/AST/Attr.cpp

+42-2
Original file line numberDiff line numberDiff line change
@@ -1027,10 +1027,12 @@ StringRef DeclAttribute::getAttrName() const {
10271027
return "<<custom>>";
10281028
case DAK_ProjectedValueProperty:
10291029
return "_projectedValueProperty";
1030-
case DAK_Differentiable:
1031-
return "differentiable";
10321030
case DAK_OriginallyDefinedIn:
10331031
return "_originallyDefinedIn";
1032+
case DAK_Differentiable:
1033+
return "differentiable";
1034+
case DAK_Derivative:
1035+
return "derivative";
10341036
}
10351037
llvm_unreachable("bad DeclAttrKind");
10361038
}
@@ -1450,6 +1452,44 @@ void DifferentiableAttr::print(llvm::raw_ostream &OS, const Decl *D,
14501452
omitAssociatedFunctions);
14511453
}
14521454

1455+
DerivativeAttr::DerivativeAttr(bool implicit, SourceLoc atLoc,
1456+
SourceRange baseRange,
1457+
DeclNameWithLoc originalName,
1458+
ArrayRef<ParsedAutoDiffParameter> params)
1459+
: DeclAttribute(DAK_Derivative, atLoc, baseRange, implicit),
1460+
OriginalFunctionName(std::move(originalName)),
1461+
NumParsedParameters(params.size()) {
1462+
std::copy(params.begin(), params.end(),
1463+
getTrailingObjects<ParsedAutoDiffParameter>());
1464+
}
1465+
1466+
DerivativeAttr::DerivativeAttr(bool implicit, SourceLoc atLoc,
1467+
SourceRange baseRange,
1468+
DeclNameWithLoc originalName,
1469+
IndexSubset *indices)
1470+
: DeclAttribute(DAK_Derivative, atLoc, baseRange, implicit),
1471+
OriginalFunctionName(std::move(originalName)), ParameterIndices(indices) {
1472+
}
1473+
1474+
DerivativeAttr *
1475+
DerivativeAttr::create(ASTContext &context, bool implicit, SourceLoc atLoc,
1476+
SourceRange baseRange, DeclNameWithLoc originalName,
1477+
ArrayRef<ParsedAutoDiffParameter> params) {
1478+
unsigned size = totalSizeToAlloc<ParsedAutoDiffParameter>(params.size());
1479+
void *mem = context.Allocate(size, alignof(DerivativeAttr));
1480+
return new (mem) DerivativeAttr(implicit, atLoc, baseRange,
1481+
std::move(originalName), params);
1482+
}
1483+
1484+
DerivativeAttr *DerivativeAttr::create(ASTContext &context, bool implicit,
1485+
SourceLoc atLoc, SourceRange baseRange,
1486+
DeclNameWithLoc originalName,
1487+
IndexSubset *indices) {
1488+
void *mem = context.Allocate(sizeof(DerivativeAttr), alignof(DerivativeAttr));
1489+
return new (mem) DerivativeAttr(implicit, atLoc, baseRange,
1490+
std::move(originalName), indices);
1491+
}
1492+
14531493
ImplementsAttr::ImplementsAttr(SourceLoc atLoc, SourceRange range,
14541494
TypeLoc ProtocolType,
14551495
DeclName MemberName,

0 commit comments

Comments
 (0)