Skip to content

Extend operator decls to allow any designated nominal type for lookup. #19756

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Oct 8, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
58 changes: 30 additions & 28 deletions include/swift/AST/Decl.h
Original file line number Diff line number Diff line change
Expand Up @@ -6383,43 +6383,45 @@ class OperatorDecl : public Decl {

Identifier name;

Identifier DesignatedProtocolName;
SourceLoc DesignatedProtocolNameLoc;
ProtocolDecl *DesignatedProtocol = nullptr;
Identifier DesignatedNominalTypeName;
SourceLoc DesignatedNominalTypeNameLoc;
NominalTypeDecl *DesignatedNominalType = nullptr;

public:
OperatorDecl(DeclKind kind, DeclContext *DC, SourceLoc OperatorLoc,
Identifier Name, SourceLoc NameLoc,
Identifier DesignatedProtocolName = Identifier(),
SourceLoc DesignatedProtocolNameLoc = SourceLoc())
Identifier DesignatedNominalTypeName = Identifier(),
SourceLoc DesignatedNominalTypeNameLoc = SourceLoc())
: Decl(kind, DC), OperatorLoc(OperatorLoc), NameLoc(NameLoc), name(Name),
DesignatedProtocolName(DesignatedProtocolName),
DesignatedProtocolNameLoc(DesignatedProtocolNameLoc) {}
DesignatedNominalTypeName(DesignatedNominalTypeName),
DesignatedNominalTypeNameLoc(DesignatedNominalTypeNameLoc) {}

OperatorDecl(DeclKind kind, DeclContext *DC, SourceLoc OperatorLoc,
Identifier Name, SourceLoc NameLoc,
ProtocolDecl *DesignatedProtocol)
NominalTypeDecl *DesignatedNominalType)
: Decl(kind, DC), OperatorLoc(OperatorLoc), NameLoc(NameLoc), name(Name),
DesignatedProtocol(DesignatedProtocol) {}
DesignatedNominalType(DesignatedNominalType) {}

SourceLoc getLoc() const { return NameLoc; }

SourceLoc getOperatorLoc() const { return OperatorLoc; }
SourceLoc getNameLoc() const { return NameLoc; }
Identifier getName() const { return name; }

Identifier getDesignatedProtocolName() const {
return DesignatedProtocolName;
Identifier getDesignatedNominalTypeName() const {
return DesignatedNominalTypeName;
}

SourceLoc getDesignatedProtocolNameLoc() const {
return DesignatedProtocolNameLoc;
SourceLoc getDesignatedNominalTypeNameLoc() const {
return DesignatedNominalTypeNameLoc;
}

ProtocolDecl *getDesignatedProtocol() const { return DesignatedProtocol; }
NominalTypeDecl *getDesignatedNominalType() const {
return DesignatedNominalType;
}

void setDesignatedProtocol(ProtocolDecl *protocol) {
DesignatedProtocol = protocol;
void setDesignatedNominalType(NominalTypeDecl *nominal) {
DesignatedNominalType = nominal;
}

static bool classof(const Decl *D) {
Expand Down Expand Up @@ -6455,9 +6457,9 @@ class InfixOperatorDecl : public OperatorDecl {
InfixOperatorDecl(DeclContext *DC, SourceLoc operatorLoc, Identifier name,
SourceLoc nameLoc, SourceLoc colonLoc,
Identifier firstIdentifier, SourceLoc firstIdentifierLoc,
ProtocolDecl *designatedProtocol)
NominalTypeDecl *designatedNominalType)
: OperatorDecl(DeclKind::InfixOperator, DC, operatorLoc, name, nameLoc,
designatedProtocol),
designatedNominalType),
ColonLoc(colonLoc), FirstIdentifierLoc(firstIdentifierLoc),
FirstIdentifier(firstIdentifier) {}

Expand Down Expand Up @@ -6504,15 +6506,15 @@ class PrefixOperatorDecl : public OperatorDecl {
public:
PrefixOperatorDecl(DeclContext *DC, SourceLoc OperatorLoc, Identifier Name,
SourceLoc NameLoc,
Identifier DesignatedProtocolName = Identifier(),
SourceLoc DesignatedProtocolNameLoc = SourceLoc())
Identifier DesignatedNominalTypeName = Identifier(),
SourceLoc DesignatedNominalTypeNameLoc = SourceLoc())
: OperatorDecl(DeclKind::PrefixOperator, DC, OperatorLoc, Name, NameLoc,
DesignatedProtocolName, DesignatedProtocolNameLoc) {}
DesignatedNominalTypeName, DesignatedNominalTypeNameLoc) {}

PrefixOperatorDecl(DeclContext *DC, SourceLoc OperatorLoc, Identifier Name,
SourceLoc NameLoc, ProtocolDecl *DesignatedProtocol)
SourceLoc NameLoc, NominalTypeDecl *DesignatedNominalType)
: OperatorDecl(DeclKind::PrefixOperator, DC, OperatorLoc, Name, NameLoc,
DesignatedProtocol) {}
DesignatedNominalType) {}

SourceRange getSourceRange() const {
return { getOperatorLoc(), getNameLoc() };
Expand All @@ -6538,15 +6540,15 @@ class PostfixOperatorDecl : public OperatorDecl {
public:
PostfixOperatorDecl(DeclContext *DC, SourceLoc OperatorLoc, Identifier Name,
SourceLoc NameLoc,
Identifier DesignatedProtocolName = Identifier(),
SourceLoc DesignatedProtocolNameLoc = SourceLoc())
Identifier DesignatedNominalTypeName = Identifier(),
SourceLoc DesignatedNominalTypeNameLoc = SourceLoc())
: OperatorDecl(DeclKind::PostfixOperator, DC, OperatorLoc, Name, NameLoc,
DesignatedProtocolName, DesignatedProtocolNameLoc) {}
DesignatedNominalTypeName, DesignatedNominalTypeNameLoc) {}

PostfixOperatorDecl(DeclContext *DC, SourceLoc OperatorLoc, Identifier Name,
SourceLoc NameLoc, ProtocolDecl *DesignatedProtocol)
SourceLoc NameLoc, NominalTypeDecl *DesignatedNominalType)
: OperatorDecl(DeclKind::PostfixOperator, DC, OperatorLoc, Name, NameLoc,
DesignatedProtocol) {}
DesignatedNominalType) {}

SourceRange getSourceRange() const {
return { getOperatorLoc(), getNameLoc() };
Expand Down
2 changes: 1 addition & 1 deletion include/swift/AST/DiagnosticsParse.def
Original file line number Diff line number Diff line change
Expand Up @@ -432,7 +432,7 @@ ERROR(operator_decl_no_fixity,none,
"operator must be declared as 'prefix', 'postfix', or 'infix'", ())

ERROR(operator_decl_trailing_comma,none,
"expected designated protocol in operator declaration", ())
"expected designated type in operator declaration", ())

// PrecedenceGroup
ERROR(precedencegroup_not_infix,none,
Expand Down
4 changes: 0 additions & 4 deletions include/swift/AST/DiagnosticsSema.def
Original file line number Diff line number Diff line change
Expand Up @@ -706,10 +706,6 @@ ERROR(unspaced_unary_operator,none,
"unary operators must not be juxtaposed; parenthesize inner expression",
())

ERROR(operators_designated_protocol_not_a_protocol,none,
"type %0 unexpected; expected a protocol type",
(Type))

ERROR(use_unresolved_identifier,none,
"use of unresolved %select{identifier|operator}1 %0", (DeclName, bool))
ERROR(use_unresolved_identifier_corrected,none,
Expand Down
6 changes: 3 additions & 3 deletions include/swift/Basic/LangOptions.h
Original file line number Diff line number Diff line change
Expand Up @@ -193,12 +193,12 @@ namespace swift {
/// Disable constraint system performance hacks.
bool DisableConstraintSolverPerformanceHacks = false;

/// \brief Enable experimental operator protocol designator feature.
bool EnableOperatorDesignatedProtocols = false;
/// \brief Enable experimental operator designated types feature.
bool EnableOperatorDesignatedTypes = false;

/// \brief Enable constraint solver support for experimental
/// operator protocol designator feature.
bool SolverEnableOperatorDesignatedProtocols = false;
bool SolverEnableOperatorDesignatedTypes = false;

/// The maximum depth to which to test decl circularity.
unsigned MaxCircularityDepth = 500;
Expand Down
12 changes: 6 additions & 6 deletions include/swift/Option/FrontendOptions.td
Original file line number Diff line number Diff line change
Expand Up @@ -373,13 +373,13 @@ def solver_disable_shrink :
def disable_constraint_solver_performance_hacks : Flag<["-"], "disable-constraint-solver-performance-hacks">,
HelpText<"Disable all the hacks in the constraint solver">;

def enable_operator_designated_protocols :
Flag<["-"], "enable-operator-designated-protocols">,
HelpText<"Enable operator designated protocols">;
def enable_operator_designated_types :
Flag<["-"], "enable-operator-designated-types">,
HelpText<"Enable operator designated types">;

def solver_enable_operator_designated_protocols :
Flag<["-"], "solver-enable-operator-designated-protocols">,
HelpText<"Enable operator designated protocols in constraint solver">;
def solver_enable_operator_designated_types :
Flag<["-"], "solver-enable-operator-designated-types">,
HelpText<"Enable operator designated types in constraint solver">;

def switch_checking_invocation_threshold_EQ : Joined<["-"],
"switch-checking-invocation-threshold=">;
Expand Down
2 changes: 1 addition & 1 deletion include/swift/Serialization/ModuleFormat.h
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ const uint16_t VERSION_MAJOR = 0;
/// describe what change you made. The content of this comment isn't important;
/// it just ensures a conflict if two people change the module format.
/// Don't worry about adhering to the 80-column limit for this line.
const uint16_t VERSION_MINOR = 451; // Last change: pattern initializer text
const uint16_t VERSION_MINOR = 452; // Last change: nominal types for operators

using DeclIDField = BCFixed<31>;

Expand Down
8 changes: 4 additions & 4 deletions lib/AST/ASTPrinter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2949,8 +2949,8 @@ void PrintAST::visitPrefixOperatorDecl(PrefixOperatorDecl *decl) {
[&]{
Printer.printName(decl->getName());
});
if (!decl->getDesignatedProtocolName().empty())
Printer << " : " << decl->getDesignatedProtocolName();
if (!decl->getDesignatedNominalTypeName().empty())
Printer << " : " << decl->getDesignatedNominalTypeName();
}

void PrintAST::visitPostfixOperatorDecl(PostfixOperatorDecl *decl) {
Expand All @@ -2960,8 +2960,8 @@ void PrintAST::visitPostfixOperatorDecl(PostfixOperatorDecl *decl) {
[&]{
Printer.printName(decl->getName());
});
if (!decl->getDesignatedProtocolName().empty())
Printer << " : " << decl->getDesignatedProtocolName();
if (!decl->getDesignatedNominalTypeName().empty())
Printer << " : " << decl->getDesignatedNominalTypeName();
}

void PrintAST::visitModuleDecl(ModuleDecl *decl) { }
Expand Down
8 changes: 4 additions & 4 deletions lib/Frontend/CompilerInvocation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -161,11 +161,11 @@ static bool ParseLangArgs(LangOptions &Opts, ArgList &Args,
Opts.EnableExperimentalPropertyBehaviors |=
Args.hasArg(OPT_enable_experimental_property_behaviors);

Opts.EnableOperatorDesignatedProtocols |=
Args.hasArg(OPT_enable_operator_designated_protocols);
Opts.EnableOperatorDesignatedTypes |=
Args.hasArg(OPT_enable_operator_designated_types);

Opts.SolverEnableOperatorDesignatedProtocols |=
Args.hasArg(OPT_solver_enable_operator_designated_protocols);
Opts.SolverEnableOperatorDesignatedTypes |=
Args.hasArg(OPT_solver_enable_operator_designated_types);

if (auto A = Args.getLastArg(OPT_enable_deserialization_recovery,
OPT_disable_deserialization_recovery)) {
Expand Down
16 changes: 11 additions & 5 deletions lib/Parse/ParseDecl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6503,11 +6503,12 @@ Parser::parseDeclOperatorImpl(SourceLoc OperatorLoc, Identifier Name,
if (Tok.is(tok::colon)) {
SyntaxParsingContext GroupCtxt(SyntaxContext, SyntaxKind::InfixOperatorGroup);
colonLoc = consumeToken();
if (Tok.is(tok::identifier)) {
firstIdentifierName = Context.getIdentifier(Tok.getText());
firstIdentifierNameLoc = consumeToken(tok::identifier);

if (Context.LangOpts.EnableOperatorDesignatedProtocols) {
if (Context.LangOpts.EnableOperatorDesignatedTypes) {
if (Tok.is(tok::identifier)) {
firstIdentifierName = Context.getIdentifier(Tok.getText());
firstIdentifierNameLoc = consumeToken(tok::identifier);

if (consumeIf(tok::comma)) {
if (isPrefix || isPostfix)
diagnose(colonLoc, diag::precedencegroup_not_infix)
Expand All @@ -6521,7 +6522,12 @@ Parser::parseDeclOperatorImpl(SourceLoc OperatorLoc, Identifier Name,
diagnose(otherTokLoc, diag::operator_decl_trailing_comma);
}
}
} else if (isPrefix || isPostfix) {
}
} else if (Tok.is(tok::identifier)) {
firstIdentifierName = Context.getIdentifier(Tok.getText());
firstIdentifierNameLoc = consumeToken(tok::identifier);

if (isPrefix || isPostfix) {
diagnose(colonLoc, diag::precedencegroup_not_infix)
.fixItRemove({colonLoc, firstIdentifierNameLoc});
}
Expand Down
15 changes: 8 additions & 7 deletions lib/Sema/CSSolver.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1629,11 +1629,12 @@ static bool isOperatorBindOverload(Constraint *bindOverload) {
// Given a bind overload constraint for an operator, return the
// protocol designated as the first place to look for overloads of the
// operator.
static ProtocolDecl *getOperatorDesignatedProtocol(Constraint *bindOverload) {
static NominalTypeDecl *
getOperatorDesignatedNominalType(Constraint *bindOverload) {
auto choice = bindOverload->getOverloadChoice();
auto *funcDecl = cast<FuncDecl>(choice.getDecl());
auto *operatorDecl = funcDecl->getOperatorDecl();
return operatorDecl->getDesignatedProtocol();
return operatorDecl->getDesignatedNominalType();
}

void ConstraintSystem::partitionDisjunction(
Expand All @@ -1649,7 +1650,7 @@ void ConstraintSystem::partitionDisjunction(
};

if (!getASTContext().isSwiftVersionAtLeast(5) ||
!TC.getLangOpts().SolverEnableOperatorDesignatedProtocols ||
!TC.getLangOpts().SolverEnableOperatorDesignatedTypes ||
!isOperatorBindOverload(Choices[0])) {
originalOrdering();
return;
Expand Down Expand Up @@ -1720,21 +1721,21 @@ void ConstraintSystem::partitionDisjunction(

// Now collect the overload choices that are defined within the type
// that was designated in the operator declaration.
auto *designatedProtocol = getOperatorDesignatedProtocol(Choices[0]);
if (designatedProtocol) {
auto *designatedNominal = getOperatorDesignatedNominalType(Choices[0]);
if (designatedNominal) {
forEachChoice(Choices, [&](unsigned index, Constraint *constraint) -> bool {
auto *decl = constraint->getOverloadChoice().getDecl();
auto *funcDecl = cast<FuncDecl>(decl);

auto *parentDecl = funcDecl->getParent()->getAsDecl();
if (parentDecl == designatedProtocol) {
if (parentDecl == designatedNominal) {
definedInDesignatedType.push_back(index);
return true;
}

if (auto *extensionDecl = dyn_cast<ExtensionDecl>(parentDecl)) {
parentDecl = extensionDecl->getExtendedNominal();
if (parentDecl == designatedProtocol) {
if (parentDecl == designatedNominal) {
definedInExtensionOfDesignatedType.push_back(index);
return true;
}
Expand Down
51 changes: 22 additions & 29 deletions lib/Sema/TypeCheckDecl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2061,9 +2061,9 @@ PrecedenceGroupDecl *TypeChecker::lookupPrecedenceGroup(DeclContext *dc,
return group;
}

static void checkDesignatedProtocol(OperatorDecl *OD, Identifier name,
SourceLoc loc, TypeChecker &tc,
ASTContext &ctx) {
static void checkDesignatedTypes(OperatorDecl *OD, Identifier name,
SourceLoc loc, TypeChecker &tc,
ASTContext &ctx) {
auto *dc = OD->getDeclContext();
auto *TyR = new (ctx) SimpleIdentTypeRepr(loc, name);
TypeLoc typeLoc = TypeLoc(TyR);
Expand All @@ -2074,17 +2074,11 @@ static void checkDesignatedProtocol(OperatorDecl *OD, Identifier name,
}

if (!typeLoc.isError()) {
auto *decl = typeLoc.getType()->getNominalOrBoundGenericNominal();
if (!decl || !isa<ProtocolDecl>(decl)) {
tc.diagnose(typeLoc.getLoc(),
diag::operators_designated_protocol_not_a_protocol,
typeLoc.getType());
OD->setInvalid();
} else {
OD->setDesignatedProtocol(cast<ProtocolDecl>(decl));
// FIXME: verify this operator has a declaration within this
// protocol with the same arity and fixity
}
auto *decl = typeLoc.getType()->getAnyNominal();
assert(decl);
OD->setDesignatedNominalType(decl);
// FIXME: verify this operator has a declaration within this
// protocol with the same arity and fixity
}
}

Expand All @@ -2099,17 +2093,16 @@ void TypeChecker::validateDecl(OperatorDecl *OD) {

auto IOD = dyn_cast<InfixOperatorDecl>(OD);

auto enableOperatorDesignatedProtocols =
getLangOpts().EnableOperatorDesignatedProtocols;
auto enableOperatorDesignatedTypes =
getLangOpts().EnableOperatorDesignatedTypes;

// Pre- or post-fix operator?
if (!IOD) {
auto *protocol = OD->getDesignatedProtocol();
auto protocolId = OD->getDesignatedProtocolName();
if (!protocol && !protocolId.empty() &&
enableOperatorDesignatedProtocols) {
auto protocolIdLoc = OD->getDesignatedProtocolNameLoc();
checkDesignatedProtocol(OD, protocolId, protocolIdLoc, *this, Context);
auto *nominal = OD->getDesignatedNominalType();
auto nominalId = OD->getDesignatedNominalTypeName();
if (!nominal && !nominalId.empty() && enableOperatorDesignatedTypes) {
auto nominalIdLoc = OD->getDesignatedNominalTypeNameLoc();
checkDesignatedTypes(OD, nominalId, nominalIdLoc, *this, Context);
}
return;
}
Expand All @@ -2127,20 +2120,20 @@ void TypeChecker::validateDecl(OperatorDecl *OD) {
}

auto secondId = IOD->getSecondIdentifier();
auto *protocol = IOD->getDesignatedProtocol();
if (!protocol && enableOperatorDesignatedProtocols) {
auto *nominal = IOD->getDesignatedNominalType();
if (!nominal && enableOperatorDesignatedTypes) {
auto secondIdLoc = IOD->getSecondIdentifierLoc();
assert(secondId.empty() || !firstId.empty());

auto protocolId = group ? secondId : firstId;
auto protocolIdLoc = group ? secondIdLoc : firstIdLoc;
if (!protocolId.empty())
checkDesignatedProtocol(IOD, protocolId, protocolIdLoc, *this, Context);
auto nominalId = group ? secondId : firstId;
auto nominalIdLoc = group ? secondIdLoc : firstIdLoc;
if (!nominalId.empty())
checkDesignatedTypes(IOD, nominalId, nominalIdLoc, *this, Context);
}

if (!group && !IOD->isInvalid()) {
if (!firstId.empty() &&
(!secondId.empty() || !IOD->getDesignatedProtocol())) {
(!secondId.empty() || !IOD->getDesignatedNominalType())) {
diagnose(firstIdLoc, diag::unknown_precedence_group, firstId);
IOD->setInvalid();
}
Expand Down
Loading