Skip to content

Commit 3c2eeb0

Browse files
committed
[AutoDiff] Rename @transposing to @transpose(of:).
Rename `@transposing` to `@transpose(of:)` and deprecate `@transposing`. `@transpose(of:)` more clearly evokes transpose registration; the syntax is otherwise unchanged. Discussed here: swiftlang#28321 (comment) Resolves TF-992. TF-999 tracks removing `@transposing` attribute in the next release. TF-1009 tracks `@transpose` syntax support for qualified names.
1 parent 682958a commit 3c2eeb0

18 files changed

+336
-190
lines changed

include/swift/AST/Attr.def

+6-1
Original file line numberDiff line numberDiff line change
@@ -524,7 +524,7 @@ SIMPLE_DECL_ATTR(noDerivative, NoDerivative,
524524
OnVar |
525525
ABIBreakingToAdd | ABIBreakingToRemove | APIBreakingToAdd | APIBreakingToRemove,
526526
94)
527-
DECL_ATTR(transposing, Transposing,
527+
DECL_ATTR(transpose, Transpose,
528528
OnFunc | LongAttribute | AllowMultipleAttributes |
529529
ABIStableToAdd | ABIBreakingToRemove | APIStableToAdd | APIBreakingToRemove |
530530
NotSerialized, 96)
@@ -545,6 +545,11 @@ DECL_ATTR(differentiating, Differentiating,
545545
OnFunc | LongAttribute | AllowMultipleAttributes |
546546
ABIStableToAdd | ABIBreakingToRemove | APIStableToAdd | APIBreakingToRemove |
547547
NotSerialized, 98)
548+
// TODO(TF-999): Remove deprecated `@transposing` attribute.
549+
DECL_ATTR(transposing, Transposing,
550+
OnFunc | LongAttribute | AllowMultipleAttributes |
551+
ABIStableToAdd | ABIBreakingToRemove | APIStableToAdd | APIBreakingToRemove |
552+
NotSerialized, 99)
548553
// SWIFT_ENABLE_TENSORFLOW END
549554

550555
#undef TYPE_ATTR

include/swift/AST/Attr.h

+23-22
Original file line numberDiff line numberDiff line change
@@ -1743,12 +1743,11 @@ using DifferentiatingAttr = DerivativeAttr;
17431743
/// Attribute that registers a function as a transpose of another function.
17441744
///
17451745
/// Examples:
1746-
/// @transposing(foo)
1747-
/// @transposing(+, wrt: (lhs, rhs))
1748-
class TransposingAttr final
1749-
: public DeclAttribute,
1750-
private llvm::TrailingObjects<TransposingAttr,
1751-
ParsedAutoDiffParameter> {
1746+
/// @transpose(of: foo)
1747+
/// @transpose(of: +, wrt: (lhs, rhs))
1748+
class TransposeAttr final
1749+
: public DeclAttribute,
1750+
private llvm::TrailingObjects<TransposeAttr, ParsedAutoDiffParameter> {
17521751
friend TrailingObjects;
17531752

17541753
/// The base type of the original function.
@@ -1764,25 +1763,24 @@ class TransposingAttr final
17641763
/// The differentiation parameters' indices, resolved by the type checker.
17651764
IndexSubset *ParameterIndices = nullptr;
17661765

1767-
explicit TransposingAttr(bool implicit, SourceLoc atLoc,
1768-
SourceRange baseRange, TypeRepr *baseType,
1769-
DeclNameWithLoc original,
1770-
ArrayRef<ParsedAutoDiffParameter> params);
1766+
explicit TransposeAttr(bool implicit, SourceLoc atLoc, SourceRange baseRange,
1767+
TypeRepr *baseType, DeclNameWithLoc original,
1768+
ArrayRef<ParsedAutoDiffParameter> params);
17711769

1772-
explicit TransposingAttr(bool implicit, SourceLoc atLoc,
1773-
SourceRange baseRange, TypeRepr *baseType,
1774-
DeclNameWithLoc original, IndexSubset *indices);
1770+
explicit TransposeAttr(bool implicit, SourceLoc atLoc, SourceRange baseRange,
1771+
TypeRepr *baseType, DeclNameWithLoc original,
1772+
IndexSubset *indices);
17751773

17761774
public:
1777-
static TransposingAttr *create(ASTContext &context, bool implicit,
1778-
SourceLoc atLoc, SourceRange baseRange,
1779-
TypeRepr *baseType, DeclNameWithLoc original,
1780-
ArrayRef<ParsedAutoDiffParameter> params);
1775+
static TransposeAttr *create(ASTContext &context, bool implicit,
1776+
SourceLoc atLoc, SourceRange baseRange,
1777+
TypeRepr *baseType, DeclNameWithLoc original,
1778+
ArrayRef<ParsedAutoDiffParameter> params);
17811779

1782-
static TransposingAttr *create(ASTContext &context, bool implicit,
1783-
SourceLoc atLoc, SourceRange baseRange,
1784-
TypeRepr *baseType, DeclNameWithLoc original,
1785-
IndexSubset *indices);
1780+
static TransposeAttr *create(ASTContext &context, bool implicit,
1781+
SourceLoc atLoc, SourceRange baseRange,
1782+
TypeRepr *baseType, DeclNameWithLoc original,
1783+
IndexSubset *indices);
17861784

17871785
TypeRepr *getBaseType() const { return BaseType; }
17881786
DeclNameWithLoc getOriginalFunctionName() const {
@@ -1815,10 +1813,13 @@ class TransposingAttr final
18151813
}
18161814

18171815
static bool classof(const DeclAttribute *DA) {
1818-
return DA->getKind() == DAK_Transposing;
1816+
return DA->getKind() == DAK_Transpose;
18191817
}
18201818
};
18211819

1820+
// TODO(TF-999): Remove deprecated `@transposing` attribute.
1821+
using TransposingAttr = TransposeAttr;
1822+
18221823
/// Relates a property to its projection value property, as described by a property wrapper. For
18231824
/// example, given
18241825
/// \code

include/swift/AST/DiagnosticsParse.def

+6-3
Original file line numberDiff line numberDiff line change
@@ -1572,11 +1572,14 @@ WARNING(attr_differentiating_deprecated,PointsToFirstBadToken,
15721572
"'@differentiating' attribute is deprecated; use '@derivative(of:)' "
15731573
"instead", ())
15741574

1575-
// transposing
1576-
ERROR(attr_transposing_expected_original_name,PointsToFirstBadToken,
1575+
// transpose
1576+
ERROR(attr_transpose_expected_original_name,PointsToFirstBadToken,
15771577
"expected an original function name", ())
1578-
ERROR(attr_transposing_expected_label_linear_or_wrt,none,
1578+
ERROR(attr_transpose_expected_label_linear_or_wrt,none,
15791579
"expected 'wrt:'", ())
1580+
WARNING(attr_transposing_deprecated,PointsToFirstBadToken,
1581+
"'@transposing' attribute is deprecated; use '@transpose(of:)' instead",
1582+
())
15801583

15811584
// transposing `wrt` parameters clause
15821585
ERROR(transposing_params_clause_expected_parameter,PointsToFirstBadToken,

include/swift/AST/DiagnosticsSema.def

+7-7
Original file line numberDiff line numberDiff line change
@@ -3024,17 +3024,17 @@ ERROR(derivative_attr_original_stored_property_unsupported,none,
30243024
ERROR(derivative_attr_original_already_has_derivative,none,
30253025
"a derivative already exists for %0", (DeclName))
30263026

3027-
// transposing
3027+
// @transpose
30283028
ERROR(transpose_params_clause_param_not_differentiable,none,
30293029
"can only transpose with respect to parameters that conform to "
30303030
"'Differentiable' and where '%0 == %0.TangentVector'", (StringRef))
3031-
ERROR(transposing_attr_overload_not_found,none,
3031+
ERROR(transpose_attr_overload_not_found,none,
30323032
"could not find function %0 with expected type %1", (DeclName, Type))
3033-
ERROR(transposing_attr_cannot_use_named_wrt_params,none,
3034-
"cannot use named 'wrt' parameters in '@transposing' attribute, found %0",
3035-
(Identifier))
3036-
ERROR(transposing_attr_result_value_not_differentiable,none,
3037-
"'@transposing' attribute requires original function result %0 to "
3033+
ERROR(transpose_attr_cannot_use_named_wrt_params,none,
3034+
"cannot use named 'wrt' parameters in '@transpose(of:)' attribute, found "
3035+
"%0", (Identifier))
3036+
ERROR(transpose_attr_result_value_not_differentiable,none,
3037+
"'@transpose(of:)' attribute requires original function result %0 to "
30383038
"conform to 'Differentiable'", (Type))
30393039

30403040
// differentiation `wrt` parameters clause

include/swift/AST/Types.h

+4-5
Original file line numberDiff line numberDiff line change
@@ -3142,7 +3142,7 @@ class AnyFunctionType : public TypeBase {
31423142
///
31433143
/// If `makeSelfParamFirst` is true, self's tangent is reordered to appear
31443144
/// first. This should be used during type-checking, e.g. type-checking
3145-
/// `@differentiable`, `@derivative`, and `@transposing` attributes.
3145+
/// `@differentiable`, `@derivative`, and `@transpose` attributes.
31463146
///
31473147
/// \note The original function type (`self`) need not be `@differentiable`.
31483148
/// The resulting function will preserve all `ExtInfo` of the original
@@ -3158,11 +3158,10 @@ class AnyFunctionType : public TypeBase {
31583158
/// corresponding original function type.
31593159
AnyFunctionType *getAutoDiffOriginalFunctionType();
31603160

3161-
/// Given the type of a transposing derivative function, returns the
3162-
/// corresponding original function type.
3161+
/// Given the type of a transpose function, returns the corresponding original
3162+
/// function type.
31633163
AnyFunctionType *
3164-
getTransposeOriginalFunctionType(IndexSubset *wrtParamIndices,
3165-
bool wrtSelf);
3164+
getTransposeOriginalFunctionType(IndexSubset *wrtParamIndices, bool wrtSelf);
31663165

31673166
AnyFunctionType *getWithoutDifferentiability() const;
31683167

include/swift/Parse/Parser.h

+8-3
Original file line numberDiff line numberDiff line change
@@ -1013,9 +1013,14 @@ class Parser {
10131013
ParserResult<DerivativeAttr> parseDifferentiatingAttribute(SourceLoc AtLoc,
10141014
SourceLoc Loc);
10151015

1016-
/// Parse the @transposing attribute.
1017-
ParserResult<TransposingAttr> parseTransposingAttribute(SourceLoc AtLoc,
1018-
SourceLoc Loc);
1016+
/// Parse the @transpose attribute.
1017+
ParserResult<TransposeAttr> parseTransposeAttribute(SourceLoc AtLoc,
1018+
SourceLoc Loc);
1019+
1020+
/// Parse the deprecated @transposing attribute.
1021+
// TODO(TF-999): Remove the deprecated `@transposing` attribute.
1022+
ParserResult<TransposeAttr> parseTransposingAttribute(SourceLoc AtLoc,
1023+
SourceLoc Loc);
10191024

10201025
/// Parse the @quoted attribute.
10211026
ParserResult<QuotedAttr> parseQuotedAttribute(SourceLoc AtLoc, SourceLoc Loc);

lib/AST/Attr.cpp

+29-29
Original file line numberDiff line numberDiff line change
@@ -1110,6 +1110,8 @@ StringRef DeclAttribute::getAttrName() const {
11101110
return "differentiable";
11111111
case DAK_Derivative:
11121112
return "derivative";
1113+
case DAK_Transpose:
1114+
return "transpose";
11131115
case DAK_Differentiating:
11141116
return "differentiating";
11151117
case DAK_Transposing:
@@ -1608,45 +1610,43 @@ DerivativeAttr *DerivativeAttr::create(ASTContext &context, bool implicit,
16081610
std::move(originalName), indices);
16091611
}
16101612

1611-
TransposingAttr::TransposingAttr(bool implicit, SourceLoc atLoc,
1612-
SourceRange baseRange, TypeRepr *baseType,
1613-
DeclNameWithLoc originalName,
1614-
ArrayRef<ParsedAutoDiffParameter> params)
1615-
: DeclAttribute(DAK_Transposing, atLoc, baseRange, implicit),
1613+
TransposeAttr::TransposeAttr(bool implicit, SourceLoc atLoc,
1614+
SourceRange baseRange, TypeRepr *baseType,
1615+
DeclNameWithLoc originalName,
1616+
ArrayRef<ParsedAutoDiffParameter> params)
1617+
: DeclAttribute(DAK_Transpose, atLoc, baseRange, implicit),
16161618
BaseType(baseType), OriginalFunctionName(std::move(originalName)),
16171619
NumParsedParameters(params.size()) {
16181620
std::uninitialized_copy(params.begin(), params.end(),
16191621
getTrailingObjects<ParsedAutoDiffParameter>());
16201622
}
16211623

1622-
TransposingAttr::TransposingAttr(bool implicit, SourceLoc atLoc,
1623-
SourceRange baseRange, TypeRepr *baseType,
1624-
DeclNameWithLoc originalName,
1625-
IndexSubset *indices)
1626-
: DeclAttribute(DAK_Transposing, atLoc, baseRange, implicit),
1624+
TransposeAttr::TransposeAttr(bool implicit, SourceLoc atLoc,
1625+
SourceRange baseRange, TypeRepr *baseType,
1626+
DeclNameWithLoc originalName, IndexSubset *indices)
1627+
: DeclAttribute(DAK_Transpose, atLoc, baseRange, implicit),
16271628
BaseType(baseType), OriginalFunctionName(std::move(originalName)),
16281629
ParameterIndices(indices) {}
16291630

1630-
TransposingAttr *
1631-
TransposingAttr::create(ASTContext &context, bool implicit, SourceLoc atLoc,
1632-
SourceRange baseRange, TypeRepr *baseType,
1633-
DeclNameWithLoc original,
1634-
ArrayRef<ParsedAutoDiffParameter> params) {
1631+
TransposeAttr *TransposeAttr::create(ASTContext &context, bool implicit,
1632+
SourceLoc atLoc, SourceRange baseRange,
1633+
TypeRepr *baseType,
1634+
DeclNameWithLoc originalName,
1635+
ArrayRef<ParsedAutoDiffParameter> params) {
16351636
unsigned size = totalSizeToAlloc<ParsedAutoDiffParameter>(params.size());
1636-
void *mem = context.Allocate(size, alignof(TransposingAttr));
1637-
return new (mem) TransposingAttr(implicit, atLoc, baseRange, baseType,
1638-
std::move(original), params);
1639-
}
1640-
1641-
TransposingAttr *
1642-
TransposingAttr::create(ASTContext &context, bool implicit, SourceLoc atLoc,
1643-
SourceRange baseRange, TypeRepr *baseType,
1644-
DeclNameWithLoc original,
1645-
IndexSubset *indices) {
1646-
void *mem =
1647-
context.Allocate(sizeof(TransposingAttr), alignof(TransposingAttr));
1648-
return new (mem) TransposingAttr(implicit, atLoc, baseRange, baseType,
1649-
std::move(original), indices);
1637+
void *mem = context.Allocate(size, alignof(TransposeAttr));
1638+
return new (mem) TransposeAttr(implicit, atLoc, baseRange, baseType,
1639+
std::move(originalName), params);
1640+
}
1641+
1642+
TransposeAttr *TransposeAttr::create(ASTContext &context, bool implicit,
1643+
SourceLoc atLoc, SourceRange baseRange,
1644+
TypeRepr *baseType,
1645+
DeclNameWithLoc originalName,
1646+
IndexSubset *indices) {
1647+
void *mem = context.Allocate(sizeof(TransposeAttr), alignof(TransposeAttr));
1648+
return new (mem) TransposeAttr(implicit, atLoc, baseRange, baseType,
1649+
std::move(originalName), indices);
16501650
}
16511651

16521652
ImplementsAttr::ImplementsAttr(SourceLoc atLoc, SourceRange range,

lib/AST/Type.cpp

+1-2
Original file line numberDiff line numberDiff line change
@@ -4897,8 +4897,7 @@ AnyFunctionType *AnyFunctionType::getTransposeOriginalFunctionType(
48974897
assert(originalResult);
48984898

48994899
SmallVector<TupleTypeElt, 4> transposeResultTypes;
4900-
// Return type of '@transposing' function can have single type or tuples
4901-
// of types.
4900+
// Return type of transpose function can be a singular type or a tuple type.
49024901
if (auto transposeResultTupleType = transposeResult->getAs<TupleType>()) {
49034902
transposeResultTypes.append(transposeResultTupleType->getElements().begin(),
49044903
transposeResultTupleType->getElements().end());

0 commit comments

Comments
 (0)