Skip to content

[AutoDiff] Rename @differentiating to @derivative(of:). #28481

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Nov 26, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion include/swift/AST/Attr.def
Original file line number Diff line number Diff line change
Expand Up @@ -512,7 +512,7 @@ DECL_ATTR(differentiable, Differentiable,
AllowMultipleAttributes |
ABIStableToAdd | ABIBreakingToRemove | APIStableToAdd | APIBreakingToRemove,
91)
DECL_ATTR(differentiating, Differentiating,
DECL_ATTR(derivative, Derivative,
OnFunc | LongAttribute | AllowMultipleAttributes |
ABIStableToAdd | ABIBreakingToRemove | APIStableToAdd | APIBreakingToRemove |
NotSerialized, 92)
Expand Down Expand Up @@ -540,6 +540,11 @@ DECL_ATTR(quoted, Quoted,
OnFunc |
ABIStableToAdd | ABIStableToRemove | APIStableToAdd | APIStableToRemove,
97)
// TODO(TF-999): Remove deprecated `@differentiating` attribute.
DECL_ATTR(differentiating, Differentiating,
OnFunc | LongAttribute | AllowMultipleAttributes |
ABIStableToAdd | ABIBreakingToRemove | APIStableToAdd | APIBreakingToRemove |
NotSerialized, 98)
// SWIFT_ENABLE_TENSORFLOW END

#undef TYPE_ATTR
Expand Down
42 changes: 21 additions & 21 deletions include/swift/AST/Attr.h
Original file line number Diff line number Diff line change
Expand Up @@ -1670,12 +1670,11 @@ class DifferentiableAttr final
/// Attribute that registers a function as a derivative of another function.
///
/// Examples:
/// @differentiating(sin(_:))
/// @differentiating(+, wrt: (lhs, rhs))
class DifferentiatingAttr final
/// @derivative(of: sin(_:))
/// @derivative(of: +, wrt: (lhs, rhs))
class DerivativeAttr final
: public DeclAttribute,
private llvm::TrailingObjects<DifferentiatingAttr,
ParsedAutoDiffParameter> {
private llvm::TrailingObjects<DerivativeAttr, ParsedAutoDiffParameter> {
friend TrailingObjects;

/// The original function name.
Expand All @@ -1687,24 +1686,22 @@ class DifferentiatingAttr final
/// The differentiation parameters' indices, resolved by the type checker.
IndexSubset *ParameterIndices = nullptr;

explicit DifferentiatingAttr(bool implicit, SourceLoc atLoc,
SourceRange baseRange, DeclNameWithLoc original,
ArrayRef<ParsedAutoDiffParameter> params);
explicit DerivativeAttr(bool implicit, SourceLoc atLoc, SourceRange baseRange,
DeclNameWithLoc original,
ArrayRef<ParsedAutoDiffParameter> params);

explicit DifferentiatingAttr(bool implicit, SourceLoc atLoc,
SourceRange baseRange, DeclNameWithLoc original,
IndexSubset *indices);
explicit DerivativeAttr(bool implicit, SourceLoc atLoc, SourceRange baseRange,
DeclNameWithLoc original, IndexSubset *indices);

public:
static DifferentiatingAttr *create(ASTContext &context, bool implicit,
SourceLoc atLoc, SourceRange baseRange,
DeclNameWithLoc original,
ArrayRef<ParsedAutoDiffParameter> params);
static DerivativeAttr *create(ASTContext &context, bool implicit,
SourceLoc atLoc, SourceRange baseRange,
DeclNameWithLoc original,
ArrayRef<ParsedAutoDiffParameter> params);

static DifferentiatingAttr *create(ASTContext &context, bool implicit,
SourceLoc atLoc, SourceRange baseRange,
DeclNameWithLoc original,
IndexSubset *indices);
static DerivativeAttr *create(ASTContext &context, bool implicit,
SourceLoc atLoc, SourceRange baseRange,
DeclNameWithLoc original, IndexSubset *indices);

DeclNameWithLoc getOriginalFunctionName() const {
return OriginalFunctionName;
Expand Down Expand Up @@ -1736,10 +1733,13 @@ class DifferentiatingAttr final
}

static bool classof(const DeclAttribute *DA) {
return DA->getKind() == DAK_Differentiating;
return DA->getKind() == DAK_Derivative;
}
};


// TODO(TF-999): Remove deprecated `@differentiating` attribute.
using DifferentiatingAttr = DerivativeAttr;

/// Attribute that registers a function as a transpose of another function.
///
/// Examples:
Expand Down
7 changes: 5 additions & 2 deletions include/swift/AST/DiagnosticsParse.def
Original file line number Diff line number Diff line change
Expand Up @@ -1561,13 +1561,16 @@ ERROR(attr_differentiable_expected_label,none,
"expected either 'wrt:' or a function specifier label, e.g. 'jvp:', "
"or 'vjp:'", ())

// differentiating
ERROR(attr_differentiating_expected_original_name,PointsToFirstBadToken,
// derivative
ERROR(attr_derivative_expected_original_name,PointsToFirstBadToken,
"expected an original function name", ())
ERROR(attr_missing_label,PointsToFirstBadToken,
"missing label '%0:' in '@%1' attribute", (StringRef, StringRef))
ERROR(attr_expected_label,none,
"expected label '%0:' in '@%1' attribute", (StringRef, StringRef))
WARNING(attr_differentiating_deprecated,PointsToFirstBadToken,
"'@differentiating' attribute is deprecated; use '@derivative(of:)' "
"instead", ())

// transposing
ERROR(attr_transposing_expected_original_name,PointsToFirstBadToken,
Expand Down
32 changes: 16 additions & 16 deletions include/swift/AST/DiagnosticsSema.def
Original file line number Diff line number Diff line change
Expand Up @@ -2995,33 +2995,33 @@ ERROR(overriding_decl_missing_differentiable_attr,none,
NOTE(protocol_witness_missing_differentiable_attr,none,
"candidate is missing attribute '%0'", (StringRef))

// @differentiating
ERROR(differentiating_attr_expected_result_tuple,none,
"'@differentiating' attribute requires function to return a two-element tuple of type "
// @derivative
ERROR(derivative_attr_expected_result_tuple,none,
"'@derivative(of:)' attribute requires function to return a two-element tuple of type "
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

One topic to discuss is whether to show @derivative(of:) or @derivative in user-facing messages.
I chose @derivative(of:), following the precedent of dynamicReplacement(for:).

Code comments mostly use @derivative without the label.

Copy link
Contributor

@rxwei rxwei Nov 26, 2019

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@dynamicReplacement(for:) doesn't take additional arguments like we do, but @derivative(of:wrt:) is not good either. This works for me.

"'(value: T..., pullback: (U.TangentVector) -> T.TangentVector...)' or "
"'(value: T..., differential: (T.TangentVector...) -> U.TangentVector)'", ())
ERROR(differentiating_attr_invalid_result_tuple_value_label,none,
"'@differentiating' attribute requires function to return a two-element "
ERROR(derivative_attr_invalid_result_tuple_value_label,none,
"'@derivative(of:)' attribute requires function to return a two-element "
"tuple (first element must have label 'value:')", ())
ERROR(differentiating_attr_invalid_result_tuple_func_label,none,
"'@differentiating' attribute requires function to return a two-element "
ERROR(derivative_attr_invalid_result_tuple_func_label,none,
"'@derivative(of:)' attribute requires function to return a two-element "
"tuple (second element must have label 'pullback:' or 'differential:')", ())
ERROR(differentiating_attr_result_value_not_differentiable,none,
"'@differentiating' attribute requires function to return a two-element "
ERROR(derivative_attr_result_value_not_differentiable,none,
"'@derivative(of:)' attribute requires function to return a two-element "
"tuple (first element type %0 must conform to 'Differentiable')", (Type))
ERROR(differentiating_attr_result_func_type_mismatch,none,
ERROR(derivative_attr_result_func_type_mismatch,none,
"function result's %0 type does not match %1", (Identifier, DeclName))
NOTE(differentiating_attr_result_func_type_mismatch_note,none,
NOTE(derivative_attr_result_func_type_mismatch_note,none,
"%0 does not have expected type %1", (Identifier, Type))
NOTE(differentiating_attr_result_func_original_note,none,
NOTE(derivative_attr_result_func_original_note,none,
"%0 defined here", (DeclName))
ERROR(differentiating_attr_overload_not_found,none,
ERROR(derivative_attr_overload_not_found,none,
"could not find function %0 with expected type %1", (DeclName, Type))
ERROR(differentiating_attr_not_in_same_file_as_original,none,
ERROR(derivative_attr_not_in_same_file_as_original,none,
"derivative not in the same file as the original function", ())
ERROR(differentiating_attr_original_stored_property_unsupported,none,
ERROR(derivative_attr_original_stored_property_unsupported,none,
"cannot register derivative for stored property %0", (DeclName))
ERROR(differentiating_attr_original_already_has_derivative,none,
ERROR(derivative_attr_original_already_has_derivative,none,
"a derivative already exists for %0", (DeclName))

// transposing
Expand Down
2 changes: 1 addition & 1 deletion include/swift/AST/Types.h
Original file line number Diff line number Diff line change
Expand Up @@ -3142,7 +3142,7 @@ class AnyFunctionType : public TypeBase {
///
/// If `makeSelfParamFirst` is true, self's tangent is reordered to appear
/// first. This should be used during type-checking, e.g. type-checking
/// `@differentiable`, `@differentiating`, and `@transposing` attributes.
/// `@differentiable`, `@derivative`, and `@transposing` attributes.
///
/// \note The original function type (`self`) need not be `@differentiable`.
/// The resulting function will preserve all `ExtInfo` of the original
Expand Down
11 changes: 8 additions & 3 deletions include/swift/Parse/Parser.h
Original file line number Diff line number Diff line change
Expand Up @@ -1004,9 +1004,14 @@ class Parser {
bool parseTransposingParametersClause(
SmallVectorImpl<ParsedAutoDiffParameter> &params, StringRef attrName);

/// Parse the @differentiating attribute.
ParserResult<DifferentiatingAttr>
parseDifferentiatingAttribute(SourceLoc AtLoc, SourceLoc Loc);
/// Parse the @derivative attribute.
ParserResult<DerivativeAttr> parseDerivativeAttribute(SourceLoc AtLoc,
SourceLoc Loc);

/// Parse the deprecated @differentiating attribute.
// TODO(TF-999): Remove the deprecated `@differentiating` attribute.
ParserResult<DerivativeAttr> parseDifferentiatingAttribute(SourceLoc AtLoc,
SourceLoc Loc);

/// Parse the @transposing attribute.
ParserResult<TransposingAttr> parseTransposingAttribute(SourceLoc AtLoc,
Expand Down
4 changes: 2 additions & 2 deletions include/swift/SIL/SILDifferentiabilityWitness.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
// indices, derivative generic signature) to derivative functions (JVP and VJP).
//
// SIL differentiability witnesses are generated from the `@differentiable`
// and `@differentiating` attributes AST declaration attributes.
// and `@derivative` attribute AST declaration attributes.
// Differentiability witnesses are canonicalized by the differentiation SIL
// transform, which fills in missing derivative functions. Canonical
// differentiability witnesses from other modules can be deserialized to look up
Expand Down Expand Up @@ -60,7 +60,7 @@ class SILDifferentiabilityWitness
/// Whether or not this differentiability witness is serialized, which allows
/// devirtualization from another module.
bool IsSerialized;
/// The AST `@differentiable` or `@differentiating` attribute from which the
/// The AST `@differentiable` or `@derivative` attribute from which the
/// differentiability witness is generated. Used for diagnostics.
/// Null if the differentiability witness is parsed from SIL or if it is
/// deserialized.
Expand Down
2 changes: 1 addition & 1 deletion lib/AST/ASTScopeCreation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -467,7 +467,7 @@ class ScopeCreator final {
// necessary to avoid verification failure:
// `ASTScopeImpl::verifyThatChildrenAreContainedWithin`.
// Perhaps this check is no longer necessary after TF-835: robust
// `@differentiating` attribute lowering.
// `@derivative` attribute lowering.
if (!diffAttr->isImplicit())
sortedDifferentiableAttrs.push_back(diffAttr);
for (auto *diffAttr : sortBySourceRange(sortedDifferentiableAttrs))
Expand Down
65 changes: 33 additions & 32 deletions lib/AST/Attr.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -921,10 +921,10 @@ bool DeclAttribute::printImpl(ASTPrinter &Printer, const PrintOptions &Options,
}

// SWIFT_ENABLE_TENSORFLOW
case DAK_Differentiating: {
Printer.printAttrName("@differentiating");
Printer << '(';
auto *attr = cast<DifferentiatingAttr>(this);
case DAK_Derivative: {
Printer.printAttrName("@derivative");
Printer << "(of: ";
auto *attr = cast<DerivativeAttr>(this);
auto *derivative = cast<AbstractFunctionDecl>(D);
Printer << attr->getOriginalFunctionName().Name;
auto diffParamsString = getDifferentiationParametersClauseString(
Expand All @@ -934,7 +934,7 @@ bool DeclAttribute::printImpl(ASTPrinter &Printer, const PrintOptions &Options,
Printer << ')';
break;
}

// SWIFT_ENABLE_TENSORFLOW
case DAK_Transposing: {
Printer.printAttrName("@transposing");
Expand Down Expand Up @@ -1108,6 +1108,8 @@ StringRef DeclAttribute::getAttrName() const {
// SWIFT_ENABLE_TENSORFLOW
case DAK_Differentiable:
return "differentiable";
case DAK_Derivative:
return "derivative";
case DAK_Differentiating:
return "differentiating";
case DAK_Transposing:
Expand Down Expand Up @@ -1568,43 +1570,42 @@ void DifferentiableAttr::print(llvm::raw_ostream &OS, const Decl *D,
}

// SWIFT_ENABLE_TENSORFLOW
DifferentiatingAttr::DifferentiatingAttr(
bool implicit, SourceLoc atLoc, SourceRange baseRange,
DeclNameWithLoc originalName, ArrayRef<ParsedAutoDiffParameter> params)
: DeclAttribute(DAK_Differentiating, atLoc, baseRange, implicit),
DerivativeAttr::DerivativeAttr(bool implicit, SourceLoc atLoc,
SourceRange baseRange,
DeclNameWithLoc originalName,
ArrayRef<ParsedAutoDiffParameter> params)
: DeclAttribute(DAK_Derivative, atLoc, baseRange, implicit),
OriginalFunctionName(std::move(originalName)),
NumParsedParameters(params.size()) {
std::copy(params.begin(), params.end(),
getTrailingObjects<ParsedAutoDiffParameter>());
}

DifferentiatingAttr::DifferentiatingAttr(bool implicit, SourceLoc atLoc,
SourceRange baseRange,
DeclNameWithLoc originalName,
IndexSubset *indices)
: DeclAttribute(DAK_Differentiating, atLoc, baseRange, implicit),
DerivativeAttr::DerivativeAttr(bool implicit, SourceLoc atLoc,
SourceRange baseRange,
DeclNameWithLoc originalName,
IndexSubset *indices)
: DeclAttribute(DAK_Derivative, atLoc, baseRange, implicit),
OriginalFunctionName(std::move(originalName)), ParameterIndices(indices) {
}

DifferentiatingAttr *
DifferentiatingAttr::create(ASTContext &context, bool implicit, SourceLoc atLoc,
SourceRange baseRange, DeclNameWithLoc original,
ArrayRef<ParsedAutoDiffParameter> params) {
DerivativeAttr *
DerivativeAttr::create(ASTContext &context, bool implicit, SourceLoc atLoc,
SourceRange baseRange, DeclNameWithLoc originalName,
ArrayRef<ParsedAutoDiffParameter> params) {
unsigned size = totalSizeToAlloc<ParsedAutoDiffParameter>(params.size());
void *mem = context.Allocate(size, alignof(DifferentiatingAttr));
return new (mem) DifferentiatingAttr(implicit, atLoc, baseRange,
std::move(original), params);
}

DifferentiatingAttr *DifferentiatingAttr::create(ASTContext &context,
bool implicit, SourceLoc atLoc,
SourceRange baseRange,
DeclNameWithLoc original,
IndexSubset *indices) {
void *mem = context.Allocate(sizeof(DifferentiatingAttr),
alignof(DifferentiatingAttr));
return new (mem) DifferentiatingAttr(implicit, atLoc, baseRange,
std::move(original), indices);
void *mem = context.Allocate(size, alignof(DerivativeAttr));
return new (mem) DerivativeAttr(implicit, atLoc, baseRange,
std::move(originalName), params);
}

DerivativeAttr *DerivativeAttr::create(ASTContext &context, bool implicit,
SourceLoc atLoc, SourceRange baseRange,
DeclNameWithLoc originalName,
IndexSubset *indices) {
void *mem = context.Allocate(sizeof(DerivativeAttr), alignof(DerivativeAttr));
return new (mem) DerivativeAttr(implicit, atLoc, baseRange,
std::move(originalName), indices);
}

TransposingAttr::TransposingAttr(bool implicit, SourceLoc atLoc,
Expand Down
4 changes: 2 additions & 2 deletions lib/AST/Decl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -541,9 +541,9 @@ SourceRange Decl::getSourceRangeIncludingAttrs() const {
for (auto Attr : getAttrs()) {
// SWIFT_ENABLE_TENSORFLOW
// Skip implicitly `@differentiable` attribute generated during
// `@differentiating` attribute type-checking.
// `@derivative` attribute type-checking.
// TODO(TF-835): Instead of generating implicit `@differentiable`
// attributes, lower `@differentiating` attributes to `[differentiable]`
// attributes, lower `@derivative` attributes to `[differentiable]`
// attributes on the referenced declaration.
if (auto *diffAttr = dyn_cast<DifferentiableAttr>(Attr))
if (diffAttr->isImplicit())
Expand Down
Loading