Skip to content

[Autodiff] Derivative Registration for the Get and Set Accessors #32614

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 19 commits into from
Jul 2, 2020
Merged
Show file tree
Hide file tree
Changes from 10 commits
Commits
Show all changes
19 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions include/swift/AST/Attr.h
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
#include "swift/AST/Ownership.h"
#include "swift/AST/PlatformKind.h"
#include "swift/AST/Requirement.h"
#include "swift/AST/StorageImpl.h"
#include "swift/AST/TrailingCallArguments.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringRef.h"
Expand Down Expand Up @@ -1718,6 +1719,7 @@ class OriginallyDefinedInAttr: public DeclAttribute {
struct DeclNameRefWithLoc {
DeclNameRef Name;
DeclNameLoc Loc;
Optional<AccessorKind> AccessorKind;
};

/// Attribute that marks a function as differentiable.
Expand Down
6 changes: 6 additions & 0 deletions include/swift/AST/DiagnosticsSema.def
Original file line number Diff line number Diff line change
Expand Up @@ -3091,6 +3091,10 @@ ERROR(derivative_attr_class_member_dynamic_self_result_unsupported,none,
ERROR(derivative_attr_nonfinal_class_init_unsupported,none,
"cannot register derivative for 'init' in a non-final class; consider "
"making %0 final", (Type))
// TODO(SR-13096): Remove this temporary diagnostic.
ERROR(derivative_attr_class_setter_unsupported,none,
"cannot yet register derivative for class property or subscript setters",
())
ERROR(derivative_attr_original_already_has_derivative,none,
"a derivative already exists for %0", (DeclName))
NOTE(derivative_attr_duplicate_note,none,
Expand Down Expand Up @@ -3129,6 +3133,8 @@ NOTE(transpose_attr_wrt_self_self_type_mismatch_note,none,
ERROR(autodiff_attr_original_decl_invalid_kind,none,
"%0 is not a 'func', 'init', 'subscript', or 'var' computed property "
"declaration", (DeclNameRef))
ERROR(autodiff_attr_accessor_not_found,none,
"%0 does not have a '%1' accessor", (DeclNameRef, StringRef))
ERROR(autodiff_attr_original_decl_none_valid_found,none,
"could not find function %0 with expected type %1", (DeclNameRef, Type))
ERROR(autodiff_attr_original_decl_not_same_type_context,none,
Expand Down
45 changes: 44 additions & 1 deletion lib/Parse/ParseDecl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1037,11 +1037,25 @@ bool Parser::parseDifferentiableAttributeArguments(
return false;
}

// Helper function that returns the accessorkind if a token is an accessor label.
static Optional<AccessorKind> isAccessorLabel(const Token& token) {
if (token.is(tok::identifier)) {
StringRef tokText = token.getText();
for (auto accessor : allAccessorKinds()) {
if (tokText == getAccessorLabel(accessor)) {
return accessor;
}
}
}
return None;
}

/// Helper function that parses 'type-identifier' for `parseQualifiedDeclName`.
/// Returns true on error. Sets `baseType` to the parsed base type if present,
/// or to `nullptr` if not. A missing base type is not considered an error.
static bool parseBaseTypeForQualifiedDeclName(Parser &P, TypeRepr *&baseType) {
baseType = nullptr;
Parser::BacktrackingScope backtrack(P);

// If base type cannot be parsed, return false (no error).
if (!P.canParseBaseTypeForQualifiedDeclName())
Expand All @@ -1057,6 +1071,18 @@ static bool parseBaseTypeForQualifiedDeclName(Parser &P, TypeRepr *&baseType) {
// `parseTypeIdentifier(/*isParsingQualifiedDeclName*/ true)` leaves the
// leading period unparsed to avoid syntax verification errors.
assert(P.startsWithSymbol(P.Tok, '.') && "false");

// Check if this is a reference to an accessor in a computed property.
// FIXME: There is an ambiguity here because instead of a computed
// property with an accessor, this could be a type with a function
// name like an accessor.
if (P.Tok.is(tok::period)) {
const Token &nextToken = P.peekToken();
if (isAccessorLabel(nextToken) != None)
return false;
}

backtrack.cancelBacktrack();
P.consumeStartingCharacterOfCurrentToken(tok::period);

// Set base type and return false (no error).
Expand All @@ -1079,6 +1105,7 @@ static bool parseBaseTypeForQualifiedDeclName(Parser &P, TypeRepr *&baseType) {
static bool parseQualifiedDeclName(Parser &P, Diag<> nameParseError,
TypeRepr *&baseType,
DeclNameRefWithLoc &original) {
{
SyntaxParsingContext DeclNameContext(P.SyntaxContext,
SyntaxKind::QualifiedDeclName);
// Parse base type.
Expand All @@ -1092,7 +1119,23 @@ static bool parseQualifiedDeclName(Parser &P, Diag<> nameParseError,
Parser::DeclNameFlag::AllowOperators);
// The base type is optional, but the final unqualified declaration name is
// not. If name could not be parsed, return true for error.
return !original.Name;
if (!original.Name)
return true;
}

// Parse to see if this is an accessor and set it's type. This is an optional field.
if (P.Tok.is(tok::period)) {
const Token &nextToken = P.peekToken();
Optional<AccessorKind> kind = isAccessorLabel(nextToken);
if (kind != None) {
original.AccessorKind = kind;
P.consumeIf(tok::period);
P.consumeIf(tok::identifier);
}
}

return false;

}

/// Parse a `@derivative(of:)` attribute, returning true on error.
Expand Down
76 changes: 62 additions & 14 deletions lib/Sema/TypeCheckAttr.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
#include "swift/AST/ParameterList.h"
#include "swift/AST/PropertyWrappers.h"
#include "swift/AST/SourceFile.h"
#include "swift/AST/StorageImpl.h"
#include "swift/AST/TypeCheckRequests.h"
#include "swift/AST/Types.h"
#include "swift/Parse/Lexer.h"
Expand Down Expand Up @@ -3613,12 +3614,14 @@ static IndexSubset *computeDifferentiabilityParameters(
// If the function declaration cannot be resolved, emits a diagnostic and
// returns nullptr.
static AbstractFunctionDecl *findAbstractFunctionDecl(
DeclNameRef funcName, SourceLoc funcNameLoc, Type baseType,
DeclNameRef funcName, SourceLoc funcNameLoc,
Optional<AccessorKind> accessorKind, Type baseType,
DeclContext *lookupContext,
const std::function<bool(AbstractFunctionDecl *)> &isValidCandidate,
const std::function<void()> &noneValidDiagnostic,
const std::function<void()> &ambiguousDiagnostic,
const std::function<void()> &notFunctionDiagnostic,
const std::function<void()> &missingAccessorDiagnostic,
NameLookupOptions lookupOptions,
const Optional<std::function<bool(AbstractFunctionDecl *)>>
&hasValidTypeCtx,
Expand All @@ -3644,6 +3647,7 @@ static AbstractFunctionDecl *findAbstractFunctionDecl(
bool wrongTypeContext = false;
bool ambiguousFuncDecl = false;
bool foundInvalid = false;
bool missingAccessor = false;

// Filter lookup results.
for (auto choice : results) {
Expand All @@ -3652,10 +3656,21 @@ static AbstractFunctionDecl *findAbstractFunctionDecl(
continue;
// Cast the candidate to an `AbstractFunctionDecl`.
auto *candidate = dyn_cast<AbstractFunctionDecl>(decl);
// If the candidate is an `AbstractStorageDecl`, use its getter as the
// candidate.
if (auto *asd = dyn_cast<AbstractStorageDecl>(decl))
candidate = asd->getAccessor(AccessorKind::Get);
// If the candidate is an `AbstractStorageDecl`, use one of its accessors as
// the candidate.
if (auto *asd = dyn_cast<AbstractStorageDecl>(decl)) {
// If accessor kind is specified, use corresponding accessor from the
// candidate. Otherwise, use the getter by default.
if (accessorKind != None) {
candidate = asd->getAccessor(accessorKind.getValue());
// Error if candidate is missing the requested accessor.
if (!candidate)
missingAccessor = true;
} else
candidate = asd->getAccessor(AccessorKind::Get);
} else if (accessorKind != None) {
missingAccessor = true;
}
if (!candidate) {
notFunction = true;
continue;
Expand All @@ -3675,8 +3690,9 @@ static AbstractFunctionDecl *findAbstractFunctionDecl(
}
resolvedCandidate = candidate;
}

// If function declaration was resolved, return it.
if (resolvedCandidate)
if (resolvedCandidate && !missingAccessor)
return resolvedCandidate;

// Otherwise, emit the appropriate diagnostic and return nullptr.
Expand All @@ -3689,6 +3705,10 @@ static AbstractFunctionDecl *findAbstractFunctionDecl(
ambiguousDiagnostic();
return nullptr;
}
if (missingAccessor) {
missingAccessorDiagnostic();
return nullptr;
}
if (wrongTypeContext) {
assert(invalidTypeCtxDiagnostic &&
"Type context diagnostic should've been specified");
Expand Down Expand Up @@ -4433,6 +4453,13 @@ static bool typeCheckDerivativeAttr(ASTContext &Ctx, Decl *D,
diag::autodiff_attr_original_decl_invalid_kind,
originalName.Name);
};
auto missingAccessorDiagnostic = [&]() {
auto accessorKind = originalName.AccessorKind.getValueOr(AccessorKind::Get);
auto accessorLabel = getAccessorLabel(accessorKind);
diags.diagnose(originalName.Loc, diag::autodiff_attr_accessor_not_found,
originalName.Name, accessorLabel);
};

std::function<void()> invalidTypeContextDiagnostic = [&]() {
diags.diagnose(originalName.Loc,
diag::autodiff_attr_original_decl_not_same_type_context,
Expand Down Expand Up @@ -4477,15 +4504,17 @@ static bool typeCheckDerivativeAttr(ASTContext &Ctx, Decl *D,

// Look up original function.
auto *originalAFD = findAbstractFunctionDecl(
originalName.Name, originalName.Loc.getBaseNameLoc(), baseType,
derivativeTypeCtx, isValidOriginal, noneValidDiagnostic,
ambiguousDiagnostic, notFunctionDiagnostic, lookupOptions,
hasValidTypeContext, invalidTypeContextDiagnostic);
originalName.Name, originalName.Loc.getBaseNameLoc(),
originalName.AccessorKind,
baseType, derivativeTypeCtx, isValidOriginal, noneValidDiagnostic,
ambiguousDiagnostic, notFunctionDiagnostic, missingAccessorDiagnostic,
lookupOptions, hasValidTypeContext, invalidTypeContextDiagnostic);
if (!originalAFD)
return true;
// Diagnose original stored properties. Stored properties cannot have custom
// registered derivatives.

if (auto *accessorDecl = dyn_cast<AccessorDecl>(originalAFD)) {
// Diagnose original stored properties. Stored properties cannot have custom
// registered derivatives.
auto *asd = accessorDecl->getStorage();
if (asd->hasStorage()) {
diags.diagnose(originalName.Loc,
Expand All @@ -4495,6 +4524,17 @@ static bool typeCheckDerivativeAttr(ASTContext &Ctx, Decl *D,
asd->getName());
return true;
}
// Diagnose original class property and subscript setters.
// TODO(SR-13096): Fix derivative function typing results regarding
// class-typed function parameters.
if (asd->getDeclContext()->getSelfClassDecl() &&
accessorDecl->getAccessorKind() == AccessorKind::Set) {
diags.diagnose(originalName.Loc,
diag::derivative_attr_class_setter_unsupported);
diags.diagnose(originalAFD->getLoc(), diag::decl_declared_here,
asd->getName());
return true;
}
}
// Diagnose if original function is an invalid class member.
bool isOriginalClassMember =
Expand Down Expand Up @@ -5002,6 +5042,13 @@ void AttributeChecker::visitTransposeAttr(TransposeAttr *attr) {
diag::autodiff_attr_original_decl_invalid_kind,
originalName.Name);
};
auto missingAccessorDiagnostic = [&]() {
auto accessorKind = originalName.AccessorKind.getValueOr(AccessorKind::Get);
auto accessorLabel = getAccessorLabel(accessorKind);
diagnose(originalName.Loc, diag::autodiff_attr_accessor_not_found,
originalName.Name, accessorLabel);
};

std::function<void()> invalidTypeContextDiagnostic = [&]() {
diagnose(originalName.Loc,
diag::autodiff_attr_original_decl_not_same_type_context,
Expand Down Expand Up @@ -5032,8 +5079,9 @@ void AttributeChecker::visitTransposeAttr(TransposeAttr *attr) {
if (attr->getBaseTypeRepr())
funcLoc = attr->getBaseTypeRepr()->getLoc();
auto *originalAFD = findAbstractFunctionDecl(
originalName.Name, funcLoc, baseType, transposeTypeCtx, isValidOriginal,
noneValidDiagnostic, ambiguousDiagnostic, notFunctionDiagnostic,
originalName.Name, funcLoc, originalName.AccessorKind, baseType,
transposeTypeCtx, isValidOriginal, noneValidDiagnostic,
ambiguousDiagnostic, notFunctionDiagnostic, missingAccessorDiagnostic,
lookupOptions, hasValidTypeContext, invalidTypeContextDiagnostic);
if (!originalAFD) {
attr->setInvalid();
Expand Down
4 changes: 2 additions & 2 deletions lib/Serialization/Deserialization.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4389,7 +4389,7 @@ llvm::Error DeclDeserializer::deserializeDeclAttributes() {
parameters);

DeclNameRefWithLoc origName{
DeclNameRef(MF.getDeclBaseName(origNameId)), DeclNameLoc()};
DeclNameRef(MF.getDeclBaseName(origNameId)), DeclNameLoc(), None};
auto derivativeKind =
getActualAutoDiffDerivativeFunctionKind(rawDerivativeKind);
if (!derivativeKind)
Expand Down Expand Up @@ -4418,7 +4418,7 @@ llvm::Error DeclDeserializer::deserializeDeclAttributes() {
scratch, isImplicit, origNameId, origDeclId, parameters);

DeclNameRefWithLoc origName{
DeclNameRef(MF.getDeclBaseName(origNameId)), DeclNameLoc()};
DeclNameRef(MF.getDeclBaseName(origNameId)), DeclNameLoc(), None};
auto *origDecl = cast<AbstractFunctionDecl>(MF.getDecl(origDeclId));
llvm::SmallBitVector parametersBitVector(parameters.size());
for (unsigned i : indices(parameters))
Expand Down
34 changes: 34 additions & 0 deletions test/AutoDiff/Parse/derivative_attr_parse.swift
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,27 @@ func dfoo(x: Float) -> (value: Float, differential: (Float) -> (Float)) {
return (x, { $0 })
}

@derivative(of: property.get) // ok
func dPropertyGetter() -> ()

@derivative(of: subscript.get) // ok
func dSubscriptGetter() -> ()

@derivative(of: subscript(_:label:).get) // ok
func dLabeledSubscriptGetter() -> ()

@derivative(of: property.set) // ok
func dPropertySetter() -> ()

@derivative(of: subscript.set) // ok
func dSubscriptSetter() -> ()

@derivative(of: subscript(_:label:).set) // ok
func dLabeledSubscriptSetter() -> ()

@derivative(of: nestedType.name) // ok
func dNestedTypeFunc() -> ()

/// Bad

// expected-error @+2 {{expected an original function name}}
Expand Down Expand Up @@ -98,3 +119,16 @@ func testLocalDerivativeRegistration() {
@derivative(of: sin)
func dsin()
}


func testLocalDerivativeRegistration() {
// expected-error @+1 {{attribute '@derivative' can only be used in a non-local scope}}
@derivative(of: sin)
func dsin()
}

// expected-error @+2 {{expected ',' separator}}
// expected-error @+1 {{expected declaration}}
@derivative(of: nestedType.name.set)
func dNestedTypePropertySetter() -> ()

Loading