Skip to content

Commit 6e11994

Browse files
committed
[AutoDiff] Rename @differentiating to @derivative(of:).
Rename `@differentiating` to `@derivative(of:)`. `@derivative(of:)` more clearly evokes derivative registration; the syntax is otherwise unchanged. Deprecate `@differentiating`, to be removed in the next release. Discussed here: swiftlang#28321 (comment) Partially resolves TF-999. TF-1000 tracks updating all `@differentiating` usages across repositories.
1 parent 658b7f7 commit 6e11994

34 files changed

+424
-274
lines changed

include/swift/AST/Attr.def

+6-1
Original file line numberDiff line numberDiff line change
@@ -512,7 +512,7 @@ DECL_ATTR(differentiable, Differentiable,
512512
AllowMultipleAttributes |
513513
ABIStableToAdd | ABIBreakingToRemove | APIStableToAdd | APIBreakingToRemove,
514514
91)
515-
DECL_ATTR(differentiating, Differentiating,
515+
DECL_ATTR(derivative, Derivative,
516516
OnFunc | LongAttribute | AllowMultipleAttributes |
517517
ABIStableToAdd | ABIBreakingToRemove | APIStableToAdd | APIBreakingToRemove |
518518
NotSerialized, 92)
@@ -540,6 +540,11 @@ DECL_ATTR(quoted, Quoted,
540540
OnFunc |
541541
ABIStableToAdd | ABIStableToRemove | APIStableToAdd | APIStableToRemove,
542542
97)
543+
// TODO(TF-999): Remove deprecated `@differentiating` attribute.
544+
DECL_ATTR(differentiating, Differentiating,
545+
OnFunc | LongAttribute | AllowMultipleAttributes |
546+
ABIStableToAdd | ABIBreakingToRemove | APIStableToAdd | APIBreakingToRemove |
547+
NotSerialized, 98)
543548
// SWIFT_ENABLE_TENSORFLOW END
544549

545550
#undef TYPE_ATTR

include/swift/AST/Attr.h

+21-21
Original file line numberDiff line numberDiff line change
@@ -1670,12 +1670,11 @@ class DifferentiableAttr final
16701670
/// Attribute that registers a function as a derivative of another function.
16711671
///
16721672
/// Examples:
1673-
/// @differentiating(sin(_:))
1674-
/// @differentiating(+, wrt: (lhs, rhs))
1675-
class DifferentiatingAttr final
1673+
/// @derivative(of: sin(_:))
1674+
/// @derivative(of: +, wrt: (lhs, rhs))
1675+
class DerivativeAttr final
16761676
: public DeclAttribute,
1677-
private llvm::TrailingObjects<DifferentiatingAttr,
1678-
ParsedAutoDiffParameter> {
1677+
private llvm::TrailingObjects<DerivativeAttr, ParsedAutoDiffParameter> {
16791678
friend TrailingObjects;
16801679

16811680
/// The original function name.
@@ -1687,24 +1686,22 @@ class DifferentiatingAttr final
16871686
/// The differentiation parameters' indices, resolved by the type checker.
16881687
IndexSubset *ParameterIndices = nullptr;
16891688

1690-
explicit DifferentiatingAttr(bool implicit, SourceLoc atLoc,
1691-
SourceRange baseRange, DeclNameWithLoc original,
1692-
ArrayRef<ParsedAutoDiffParameter> params);
1689+
explicit DerivativeAttr(bool implicit, SourceLoc atLoc, SourceRange baseRange,
1690+
DeclNameWithLoc original,
1691+
ArrayRef<ParsedAutoDiffParameter> params);
16931692

1694-
explicit DifferentiatingAttr(bool implicit, SourceLoc atLoc,
1695-
SourceRange baseRange, DeclNameWithLoc original,
1696-
IndexSubset *indices);
1693+
explicit DerivativeAttr(bool implicit, SourceLoc atLoc, SourceRange baseRange,
1694+
DeclNameWithLoc original, IndexSubset *indices);
16971695

16981696
public:
1699-
static DifferentiatingAttr *create(ASTContext &context, bool implicit,
1700-
SourceLoc atLoc, SourceRange baseRange,
1701-
DeclNameWithLoc original,
1702-
ArrayRef<ParsedAutoDiffParameter> params);
1697+
static DerivativeAttr *create(ASTContext &context, bool implicit,
1698+
SourceLoc atLoc, SourceRange baseRange,
1699+
DeclNameWithLoc original,
1700+
ArrayRef<ParsedAutoDiffParameter> params);
17031701

1704-
static DifferentiatingAttr *create(ASTContext &context, bool implicit,
1705-
SourceLoc atLoc, SourceRange baseRange,
1706-
DeclNameWithLoc original,
1707-
IndexSubset *indices);
1702+
static DerivativeAttr *create(ASTContext &context, bool implicit,
1703+
SourceLoc atLoc, SourceRange baseRange,
1704+
DeclNameWithLoc original, IndexSubset *indices);
17081705

17091706
DeclNameWithLoc getOriginalFunctionName() const {
17101707
return OriginalFunctionName;
@@ -1736,10 +1733,13 @@ class DifferentiatingAttr final
17361733
}
17371734

17381735
static bool classof(const DeclAttribute *DA) {
1739-
return DA->getKind() == DAK_Differentiating;
1736+
return DA->getKind() == DAK_Derivative;
17401737
}
17411738
};
1742-
1739+
1740+
// TODO(TF-999): Remove deprecated `@differentiating` attribute.
1741+
using DifferentiatingAttr = DerivativeAttr;
1742+
17431743
/// Attribute that registers a function as a transpose of another function.
17441744
///
17451745
/// Examples:

include/swift/AST/DiagnosticsParse.def

+5-2
Original file line numberDiff line numberDiff line change
@@ -1561,13 +1561,16 @@ ERROR(attr_differentiable_expected_label,none,
15611561
"expected either 'wrt:' or a function specifier label, e.g. 'jvp:', "
15621562
"or 'vjp:'", ())
15631563

1564-
// differentiating
1565-
ERROR(attr_differentiating_expected_original_name,PointsToFirstBadToken,
1564+
// derivative
1565+
ERROR(attr_derivative_expected_original_name,PointsToFirstBadToken,
15661566
"expected an original function name", ())
15671567
ERROR(attr_missing_label,PointsToFirstBadToken,
15681568
"missing label '%0:' in '@%1' attribute", (StringRef, StringRef))
15691569
ERROR(attr_expected_label,none,
15701570
"expected label '%0:' in '@%1' attribute", (StringRef, StringRef))
1571+
WARNING(attr_differentiating_deprecated,PointsToFirstBadToken,
1572+
"'@differentiating' attribute is deprecated; use '@derivative(of:)' "
1573+
"instead", ())
15711574

15721575
// transposing
15731576
ERROR(attr_transposing_expected_original_name,PointsToFirstBadToken,

include/swift/AST/DiagnosticsSema.def

+16-16
Original file line numberDiff line numberDiff line change
@@ -2995,33 +2995,33 @@ ERROR(overriding_decl_missing_differentiable_attr,none,
29952995
NOTE(protocol_witness_missing_differentiable_attr,none,
29962996
"candidate is missing attribute '%0'", (StringRef))
29972997

2998-
// @differentiating
2999-
ERROR(differentiating_attr_expected_result_tuple,none,
3000-
"'@differentiating' attribute requires function to return a two-element tuple of type "
2998+
// @derivative
2999+
ERROR(derivative_attr_expected_result_tuple,none,
3000+
"'@derivative(of:)' attribute requires function to return a two-element tuple of type "
30013001
"'(value: T..., pullback: (U.TangentVector) -> T.TangentVector...)' or "
30023002
"'(value: T..., differential: (T.TangentVector...) -> U.TangentVector)'", ())
3003-
ERROR(differentiating_attr_invalid_result_tuple_value_label,none,
3004-
"'@differentiating' attribute requires function to return a two-element "
3003+
ERROR(derivative_attr_invalid_result_tuple_value_label,none,
3004+
"'@derivative(of:)' attribute requires function to return a two-element "
30053005
"tuple (first element must have label 'value:')", ())
3006-
ERROR(differentiating_attr_invalid_result_tuple_func_label,none,
3007-
"'@differentiating' attribute requires function to return a two-element "
3006+
ERROR(derivative_attr_invalid_result_tuple_func_label,none,
3007+
"'@derivative(of:)' attribute requires function to return a two-element "
30083008
"tuple (second element must have label 'pullback:' or 'differential:')", ())
3009-
ERROR(differentiating_attr_result_value_not_differentiable,none,
3010-
"'@differentiating' attribute requires function to return a two-element "
3009+
ERROR(derivative_attr_result_value_not_differentiable,none,
3010+
"'@derivative(of:)' attribute requires function to return a two-element "
30113011
"tuple (first element type %0 must conform to 'Differentiable')", (Type))
3012-
ERROR(differentiating_attr_result_func_type_mismatch,none,
3012+
ERROR(derivative_attr_result_func_type_mismatch,none,
30133013
"function result's %0 type does not match %1", (Identifier, DeclName))
3014-
NOTE(differentiating_attr_result_func_type_mismatch_note,none,
3014+
NOTE(derivative_attr_result_func_type_mismatch_note,none,
30153015
"%0 does not have expected type %1", (Identifier, Type))
3016-
NOTE(differentiating_attr_result_func_original_note,none,
3016+
NOTE(derivative_attr_result_func_original_note,none,
30173017
"%0 defined here", (DeclName))
3018-
ERROR(differentiating_attr_overload_not_found,none,
3018+
ERROR(derivative_attr_overload_not_found,none,
30193019
"could not find function %0 with expected type %1", (DeclName, Type))
3020-
ERROR(differentiating_attr_not_in_same_file_as_original,none,
3020+
ERROR(derivative_attr_not_in_same_file_as_original,none,
30213021
"derivative not in the same file as the original function", ())
3022-
ERROR(differentiating_attr_original_stored_property_unsupported,none,
3022+
ERROR(derivative_attr_original_stored_property_unsupported,none,
30233023
"cannot register derivative for stored property %0", (DeclName))
3024-
ERROR(differentiating_attr_original_already_has_derivative,none,
3024+
ERROR(derivative_attr_original_already_has_derivative,none,
30253025
"a derivative already exists for %0", (DeclName))
30263026

30273027
// transposing

include/swift/AST/Types.h

+1-1
Original file line numberDiff line numberDiff line change
@@ -3142,7 +3142,7 @@ class AnyFunctionType : public TypeBase {
31423142
///
31433143
/// If `makeSelfParamFirst` is true, self's tangent is reordered to appear
31443144
/// first. This should be used during type-checking, e.g. type-checking
3145-
/// `@differentiable`, `@differentiating`, and `@transposing` attributes.
3145+
/// `@differentiable`, `@derivative`, and `@transposing` attributes.
31463146
///
31473147
/// \note The original function type (`self`) need not be `@differentiable`.
31483148
/// The resulting function will preserve all `ExtInfo` of the original

include/swift/Parse/Parser.h

+8-3
Original file line numberDiff line numberDiff line change
@@ -1004,9 +1004,14 @@ class Parser {
10041004
bool parseTransposingParametersClause(
10051005
SmallVectorImpl<ParsedAutoDiffParameter> &params, StringRef attrName);
10061006

1007-
/// Parse the @differentiating attribute.
1008-
ParserResult<DifferentiatingAttr>
1009-
parseDifferentiatingAttribute(SourceLoc AtLoc, SourceLoc Loc);
1007+
/// Parse the @derivative attribute.
1008+
ParserResult<DerivativeAttr> parseDerivativeAttribute(SourceLoc AtLoc,
1009+
SourceLoc Loc);
1010+
1011+
/// Parse the deprecated @differentiating attribute.
1012+
// TODO(TF-999): Remove the deprecated `@differentiating` attribute.
1013+
ParserResult<DerivativeAttr> parseDifferentiatingAttribute(SourceLoc AtLoc,
1014+
SourceLoc Loc);
10101015

10111016
/// Parse the @transposing attribute.
10121017
ParserResult<TransposingAttr> parseTransposingAttribute(SourceLoc AtLoc,

include/swift/SIL/SILDifferentiabilityWitness.h

+2-2
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
// indices, derivative generic signature) to derivative functions (JVP and VJP).
1616
//
1717
// SIL differentiability witnesses are generated from the `@differentiable`
18-
// and `@differentiating` attributes AST declaration attributes.
18+
// and `@derivative` attribute AST declaration attributes.
1919
// Differentiability witnesses are canonicalized by the differentiation SIL
2020
// transform, which fills in missing derivative functions. Canonical
2121
// differentiability witnesses from other modules can be deserialized to look up
@@ -60,7 +60,7 @@ class SILDifferentiabilityWitness
6060
/// Whether or not this differentiability witness is serialized, which allows
6161
/// devirtualization from another module.
6262
bool IsSerialized;
63-
/// The AST `@differentiable` or `@differentiating` attribute from which the
63+
/// The AST `@differentiable` or `@derivative` attribute from which the
6464
/// differentiability witness is generated. Used for diagnostics.
6565
/// Null if the differentiability witness is parsed from SIL or if it is
6666
/// deserialized.

lib/AST/ASTScopeCreation.cpp

+1-1
Original file line numberDiff line numberDiff line change
@@ -467,7 +467,7 @@ class ScopeCreator final {
467467
// necessary to avoid verification failure:
468468
// `ASTScopeImpl::verifyThatChildrenAreContainedWithin`.
469469
// Perhaps this check is no longer necessary after TF-835: robust
470-
// `@differentiating` attribute lowering.
470+
// `@derivative` attribute lowering.
471471
if (!diffAttr->isImplicit())
472472
sortedDifferentiableAttrs.push_back(diffAttr);
473473
for (auto *diffAttr : sortBySourceRange(sortedDifferentiableAttrs))

lib/AST/Attr.cpp

+33-32
Original file line numberDiff line numberDiff line change
@@ -921,10 +921,10 @@ bool DeclAttribute::printImpl(ASTPrinter &Printer, const PrintOptions &Options,
921921
}
922922

923923
// SWIFT_ENABLE_TENSORFLOW
924-
case DAK_Differentiating: {
925-
Printer.printAttrName("@differentiating");
926-
Printer << '(';
927-
auto *attr = cast<DifferentiatingAttr>(this);
924+
case DAK_Derivative: {
925+
Printer.printAttrName("@derivative");
926+
Printer << "(of: ";
927+
auto *attr = cast<DerivativeAttr>(this);
928928
auto *derivative = cast<AbstractFunctionDecl>(D);
929929
Printer << attr->getOriginalFunctionName().Name;
930930
auto diffParamsString = getDifferentiationParametersClauseString(
@@ -934,7 +934,7 @@ bool DeclAttribute::printImpl(ASTPrinter &Printer, const PrintOptions &Options,
934934
Printer << ')';
935935
break;
936936
}
937-
937+
938938
// SWIFT_ENABLE_TENSORFLOW
939939
case DAK_Transposing: {
940940
Printer.printAttrName("@transposing");
@@ -1108,6 +1108,8 @@ StringRef DeclAttribute::getAttrName() const {
11081108
// SWIFT_ENABLE_TENSORFLOW
11091109
case DAK_Differentiable:
11101110
return "differentiable";
1111+
case DAK_Derivative:
1112+
return "derivative";
11111113
case DAK_Differentiating:
11121114
return "differentiating";
11131115
case DAK_Transposing:
@@ -1568,43 +1570,42 @@ void DifferentiableAttr::print(llvm::raw_ostream &OS, const Decl *D,
15681570
}
15691571

15701572
// SWIFT_ENABLE_TENSORFLOW
1571-
DifferentiatingAttr::DifferentiatingAttr(
1572-
bool implicit, SourceLoc atLoc, SourceRange baseRange,
1573-
DeclNameWithLoc originalName, ArrayRef<ParsedAutoDiffParameter> params)
1574-
: DeclAttribute(DAK_Differentiating, atLoc, baseRange, implicit),
1573+
DerivativeAttr::DerivativeAttr(bool implicit, SourceLoc atLoc,
1574+
SourceRange baseRange,
1575+
DeclNameWithLoc originalName,
1576+
ArrayRef<ParsedAutoDiffParameter> params)
1577+
: DeclAttribute(DAK_Derivative, atLoc, baseRange, implicit),
15751578
OriginalFunctionName(std::move(originalName)),
15761579
NumParsedParameters(params.size()) {
15771580
std::copy(params.begin(), params.end(),
15781581
getTrailingObjects<ParsedAutoDiffParameter>());
15791582
}
15801583

1581-
DifferentiatingAttr::DifferentiatingAttr(bool implicit, SourceLoc atLoc,
1582-
SourceRange baseRange,
1583-
DeclNameWithLoc originalName,
1584-
IndexSubset *indices)
1585-
: DeclAttribute(DAK_Differentiating, atLoc, baseRange, implicit),
1584+
DerivativeAttr::DerivativeAttr(bool implicit, SourceLoc atLoc,
1585+
SourceRange baseRange,
1586+
DeclNameWithLoc originalName,
1587+
IndexSubset *indices)
1588+
: DeclAttribute(DAK_Derivative, atLoc, baseRange, implicit),
15861589
OriginalFunctionName(std::move(originalName)), ParameterIndices(indices) {
15871590
}
15881591

1589-
DifferentiatingAttr *
1590-
DifferentiatingAttr::create(ASTContext &context, bool implicit, SourceLoc atLoc,
1591-
SourceRange baseRange, DeclNameWithLoc original,
1592-
ArrayRef<ParsedAutoDiffParameter> params) {
1592+
DerivativeAttr *
1593+
DerivativeAttr::create(ASTContext &context, bool implicit, SourceLoc atLoc,
1594+
SourceRange baseRange, DeclNameWithLoc originalName,
1595+
ArrayRef<ParsedAutoDiffParameter> params) {
15931596
unsigned size = totalSizeToAlloc<ParsedAutoDiffParameter>(params.size());
1594-
void *mem = context.Allocate(size, alignof(DifferentiatingAttr));
1595-
return new (mem) DifferentiatingAttr(implicit, atLoc, baseRange,
1596-
std::move(original), params);
1597-
}
1598-
1599-
DifferentiatingAttr *DifferentiatingAttr::create(ASTContext &context,
1600-
bool implicit, SourceLoc atLoc,
1601-
SourceRange baseRange,
1602-
DeclNameWithLoc original,
1603-
IndexSubset *indices) {
1604-
void *mem = context.Allocate(sizeof(DifferentiatingAttr),
1605-
alignof(DifferentiatingAttr));
1606-
return new (mem) DifferentiatingAttr(implicit, atLoc, baseRange,
1607-
std::move(original), indices);
1597+
void *mem = context.Allocate(size, alignof(DerivativeAttr));
1598+
return new (mem) DerivativeAttr(implicit, atLoc, baseRange,
1599+
std::move(originalName), params);
1600+
}
1601+
1602+
DerivativeAttr *DerivativeAttr::create(ASTContext &context, bool implicit,
1603+
SourceLoc atLoc, SourceRange baseRange,
1604+
DeclNameWithLoc originalName,
1605+
IndexSubset *indices) {
1606+
void *mem = context.Allocate(sizeof(DerivativeAttr), alignof(DerivativeAttr));
1607+
return new (mem) DerivativeAttr(implicit, atLoc, baseRange,
1608+
std::move(originalName), indices);
16081609
}
16091610

16101611
TransposingAttr::TransposingAttr(bool implicit, SourceLoc atLoc,

0 commit comments

Comments
 (0)