Skip to content

[HLSL] Implement output parameter #101083

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 30 commits into from
Aug 31, 2024
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
Show all changes
30 commits
Select commit Hold shift + click to select a range
e8ec3e2
[HLSL] Implement output parameter
llvm-beanz Jul 29, 2024
ac6799f
Fix aggregate copy emission
llvm-beanz Jul 29, 2024
d865dbf
Move parameter ABI qualifier mismatch
llvm-beanz Jul 30, 2024
6c8c058
Fix clang-format
llvm-beanz Jul 30, 2024
3441e91
Limit when function declarations are merged
llvm-beanz Jul 30, 2024
2283f0a
Resolve FIXME/TODO notes
llvm-beanz Jul 31, 2024
6b66117
One last fix to correctly annotate parameter types
llvm-beanz Jul 31, 2024
7acfe26
Update clang/test/SemaHLSL/Language/OutputParameters.hlsl
llvm-beanz Aug 1, 2024
c094804
Fix include ordering
llvm-beanz Aug 5, 2024
1f1f398
Merge remote-tracking branch 'origin/main' into cbieneman/inout
llvm-beanz Aug 19, 2024
adb4d3c
Update for upstream changes
llvm-beanz Aug 19, 2024
a529de0
Update based on PR feedback
llvm-beanz Aug 26, 2024
5f7e351
Fixed ActOnBinOp
llvm-beanz Aug 27, 2024
80c4c8d
Merge remote-tracking branch 'origin/main' into cbieneman/inout
llvm-beanz Aug 27, 2024
7ab1395
Updates based on latest feedback
llvm-beanz Aug 27, 2024
1bb6cae
clang-format
llvm-beanz Aug 28, 2024
7b21187
clang-format
llvm-beanz Aug 28, 2024
54ebfa0
Merge remote-tracking branch 'origin/main' into cbieneman/inout
llvm-beanz Aug 28, 2024
821c565
Add an assert in SemaHLSL::getInoutParameterType
llvm-beanz Aug 29, 2024
4c2be5c
Process the initial lvalue in SemaChecking
llvm-beanz Aug 29, 2024
b883f66
Avoid double mutating parameter modifiers
llvm-beanz Aug 29, 2024
55804b1
Update clang/include/clang/AST/Expr.h
llvm-beanz Aug 30, 2024
57dfb16
Update clang/include/clang/AST/Expr.h
llvm-beanz Aug 30, 2024
f407786
Update code comment
llvm-beanz Aug 31, 2024
e39cd29
Remove unreachable case
llvm-beanz Aug 31, 2024
92aa4c0
Rename CastExpr to WritebackExpr
llvm-beanz Aug 31, 2024
3c7a96f
Extract out common helper code
llvm-beanz Aug 31, 2024
bd003fd
Merge remote-tracking branch 'origin/main' into cbieneman/inout
llvm-beanz Aug 31, 2024
ddb0711
Add comment
llvm-beanz Aug 31, 2024
6e5c7ce
Merge remote-tracking branch 'origin/main' into cbieneman/inout
llvm-beanz Aug 31, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 24 additions & 14 deletions clang/include/clang/AST/Attr.h
Original file line number Diff line number Diff line change
Expand Up @@ -224,20 +224,7 @@ class ParameterABIAttr : public InheritableParamAttr {
InheritEvenIfAlreadyPresent) {}

public:
ParameterABI getABI() const {
switch (getKind()) {
case attr::SwiftContext:
return ParameterABI::SwiftContext;
case attr::SwiftAsyncContext:
return ParameterABI::SwiftAsyncContext;
case attr::SwiftErrorResult:
return ParameterABI::SwiftErrorResult;
case attr::SwiftIndirectResult:
return ParameterABI::SwiftIndirectResult;
default:
llvm_unreachable("bad parameter ABI attribute kind");
}
}
ParameterABI getABI() const;

static bool classof(const Attr *A) {
return A->getKind() >= attr::FirstParameterABIAttr &&
Expand Down Expand Up @@ -379,6 +366,29 @@ inline const StreamingDiagnostic &operator<<(const StreamingDiagnostic &DB,
DB.AddTaggedVal(reinterpret_cast<uint64_t>(At), DiagnosticsEngine::ak_attr);
return DB;
}

inline ParameterABI ParameterABIAttr::getABI() const {
switch (getKind()) {
case attr::SwiftContext:
return ParameterABI::SwiftContext;
case attr::SwiftAsyncContext:
return ParameterABI::SwiftAsyncContext;
case attr::SwiftErrorResult:
return ParameterABI::SwiftErrorResult;
case attr::SwiftIndirectResult:
return ParameterABI::SwiftIndirectResult;
case attr::HLSLParamModifier: {
const auto *A = cast<HLSLParamModifierAttr>(this);
if (A->isOut())
return ParameterABI::HLSLOut;
if (A->isInOut())
return ParameterABI::HLSLInOut;
return ParameterABI::Ordinary;
}
default:
llvm_unreachable("bad parameter ABI attribute kind");
}
}
} // end namespace clang

#endif
61 changes: 61 additions & 0 deletions clang/include/clang/AST/Expr.h
Original file line number Diff line number Diff line change
Expand Up @@ -7061,6 +7061,67 @@ class ArraySectionExpr : public Expr {
void setRBracketLoc(SourceLocation L) { RBracketLoc = L; }
};

/// This class represents temporary values used to represent inout and out
/// arguments in HLSL. From the callee perspective these parameters are more or
/// less __restrict__ T&. They are guaranteed to not alias any memory. inout
/// parameters are initialized by the caller, and out parameters are references
/// to uninitialized memory.
///
/// In the caller, the argument expression creates a temporary in local memory
/// and the address of the temporary is passed into the callee. There may be
/// implicit conversion sequences to initialize the temporary, and on expiration
/// of the temporary an inverse conversion sequence is applied as a write-back
/// conversion to the source l-value.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Somewhere in this declaration — either here or on the getter/setter pairs below — you need to actually document the structure of the sub-expressions. That's needed whenever the sub-expressions don't obviously reflect the structure of the source code, which at minimum is true whenever an OVE is involved.

class HLSLOutArgExpr : public Expr {
friend class ASTStmtReader;

Expr *Base;
Expr *Writeback;
OpaqueValueExpr *OpaqueVal;
bool IsInOut;

HLSLOutArgExpr(QualType Ty, Expr *B, Expr *WB, OpaqueValueExpr *OpV,
bool IsInOut)
: Expr(HLSLOutArgExprClass, Ty, VK_LValue, OK_Ordinary), Base(B),
Writeback(WB), OpaqueVal(OpV), IsInOut(IsInOut) {
assert(!Ty->isDependentType() && "HLSLOutArgExpr given a dependent type!");
}

explicit HLSLOutArgExpr(EmptyShell Shell)
: Expr(HLSLOutArgExprClass, Shell) {}

public:
static HLSLOutArgExpr *Create(const ASTContext &C, QualType Ty, Expr *Base,
bool IsInOut, Expr *WB, OpaqueValueExpr *OpV);
static HLSLOutArgExpr *CreateEmpty(const ASTContext &Ctx);

const Expr *getBase() const { return Base; }
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You need to document what the expressions are here, because it's pretty subtle. So the way you've written it, this a conversion of the base l-value to the type of the parameter, the writeback is an assignment of the temporary result back to the base l-value, and the OVE represents the temporary result. I think this is problematic, though; see the comment in Sema.

Expr *getBase() { return Base; }

const Expr *getWriteback() const { return Writeback; }
Expr *getWriteback() { return Writeback; }

const OpaqueValueExpr *getOpaqueValue() const { return OpaqueVal; }
OpaqueValueExpr *getOpaqueValue() { return OpaqueVal; }

bool isInOut() const { return IsInOut; }

SourceLocation getBeginLoc() const LLVM_READONLY {
return Base->getBeginLoc();
}

SourceLocation getEndLoc() const LLVM_READONLY { return Base->getEndLoc(); }

static bool classof(const Stmt *T) {
return T->getStmtClass() == HLSLOutArgExprClass;
}

// Iterators
child_range children() {
return child_range((Stmt **)&Base, ((Stmt **)&Writeback) + 1);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I certainly hope we don't do this elsewhere. Please declare an array field and make the accessors pull the elements out.

}
};

/// Frontend produces RecoveryExprs on semantic errors that prevent creating
/// other well-formed expressions. E.g. when type-checking of a binary operator
/// fails, we cannot produce a BinaryOperator expression. Instead, we can choose
Expand Down
3 changes: 3 additions & 0 deletions clang/include/clang/AST/RecursiveASTVisitor.h
Original file line number Diff line number Diff line change
Expand Up @@ -4014,6 +4014,9 @@ DEF_TRAVERSE_STMT(OpenACCComputeConstruct,
DEF_TRAVERSE_STMT(OpenACCLoopConstruct,
{ TRY_TO(TraverseOpenACCAssociatedStmtConstruct(S)); })

// Traverse HLSL: Out argument expression
DEF_TRAVERSE_STMT(HLSLOutArgExpr, {})

// FIXME: look at the following tricky-seeming exprs to see if we
// need to recurse on anything. These are ones that have methods
// returning decls or qualtypes or nestednamespecifier -- though I'm
Expand Down
1 change: 1 addition & 0 deletions clang/include/clang/AST/TextNodeDumper.h
Original file line number Diff line number Diff line change
Expand Up @@ -407,6 +407,7 @@ class TextNodeDumper
void
VisitLifetimeExtendedTemporaryDecl(const LifetimeExtendedTemporaryDecl *D);
void VisitHLSLBufferDecl(const HLSLBufferDecl *D);
void VisitHLSLOutArgExpr(const HLSLOutArgExpr *E);
void VisitOpenACCConstructStmt(const OpenACCConstructStmt *S);
void VisitOpenACCLoopConstruct(const OpenACCLoopConstruct *S);
void VisitEmbedExpr(const EmbedExpr *S);
Expand Down
3 changes: 1 addition & 2 deletions clang/include/clang/Basic/Attr.td
Original file line number Diff line number Diff line change
Expand Up @@ -4613,14 +4613,13 @@ def HLSLGroupSharedAddressSpace : TypeAttr {
let Documentation = [HLSLGroupSharedAddressSpaceDocs];
}

def HLSLParamModifier : TypeAttr {
def HLSLParamModifier : ParameterABIAttr {
let Spellings = [CustomKeyword<"in">, CustomKeyword<"inout">, CustomKeyword<"out">];
let Accessors = [Accessor<"isIn", [CustomKeyword<"in">]>,
Accessor<"isInOut", [CustomKeyword<"inout">]>,
Accessor<"isOut", [CustomKeyword<"out">]>,
Accessor<"isAnyOut", [CustomKeyword<"out">, CustomKeyword<"inout">]>,
Accessor<"isAnyIn", [CustomKeyword<"in">, CustomKeyword<"inout">]>];
let Subjects = SubjectList<[ParmVar]>;
let Documentation = [HLSLParamQualifierDocs];
let Args = [DefaultBoolArgument<"MergedSpelling", /*default*/0, /*fake*/1>];
}
Expand Down
2 changes: 2 additions & 0 deletions clang/include/clang/Basic/DiagnosticSemaKinds.td
Original file line number Diff line number Diff line change
Expand Up @@ -12357,6 +12357,8 @@ def warn_hlsl_availability : Warning<
def warn_hlsl_availability_unavailable :
Warning<err_unavailable.Summary>,
InGroup<HLSLAvailability>, DefaultError;
def error_hlsl_inout_scalar_extension : Error<"illegal scalar extension cast on argument %0 to %select{|in}1out paramemter">;
def error_hlsl_inout_lvalue : Error<"cannot bind non-lvalue argument %0 to %select{|in}1out paramemter">;

def err_hlsl_export_not_on_function : Error<
"export declaration can only be used on functions">;
Expand Down
6 changes: 6 additions & 0 deletions clang/include/clang/Basic/Specifiers.h
Original file line number Diff line number Diff line change
Expand Up @@ -382,6 +382,12 @@ namespace clang {
/// Swift asynchronous context-pointer ABI treatment. There can be at
/// most one parameter on a given function that uses this treatment.
SwiftAsyncContext,

// This parameter is a copy-out HLSL parameter.
HLSLOut,

// This parameter is a copy-in/copy-out HLSL parameter.
HLSLInOut,
};

/// Assigned inheritance model for a class in the MS C++ ABI. Must match order
Expand Down
3 changes: 3 additions & 0 deletions clang/include/clang/Basic/StmtNodes.td
Original file line number Diff line number Diff line change
Expand Up @@ -306,3 +306,6 @@ def OpenACCAssociatedStmtConstruct
: StmtNode<OpenACCConstructStmt, /*abstract=*/1>;
def OpenACCComputeConstruct : StmtNode<OpenACCAssociatedStmtConstruct>;
def OpenACCLoopConstruct : StmtNode<OpenACCAssociatedStmtConstruct>;

// HLSL Constructs.
def HLSLOutArgExpr : StmtNode<Expr>;
3 changes: 3 additions & 0 deletions clang/include/clang/Sema/SemaHLSL.h
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,9 @@ class SemaHLSL : public SemaBase {
void handleParamModifierAttr(Decl *D, const ParsedAttr &AL);

bool CheckBuiltinFunctionCall(unsigned BuiltinID, CallExpr *TheCall);
bool CheckCompatibleParameterABI(FunctionDecl *New, FunctionDecl *Old);

ExprResult ActOnOutParamExpr(ParmVarDecl *Param, Expr *Arg);
};

} // namespace clang
Expand Down
3 changes: 3 additions & 0 deletions clang/include/clang/Serialization/ASTBitCodes.h
Original file line number Diff line number Diff line change
Expand Up @@ -1988,6 +1988,9 @@ enum StmtCode {
// OpenACC Constructs
STMT_OPENACC_COMPUTE_CONSTRUCT,
STMT_OPENACC_LOOP_CONSTRUCT,

// HLSL Constructs
EXPR_HLSL_OUT_ARG,
};

/// The kinds of designators that can occur in a
Expand Down
11 changes: 11 additions & 0 deletions clang/lib/AST/Expr.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3631,6 +3631,7 @@ bool Expr::HasSideEffects(const ASTContext &Ctx,
case RequiresExprClass:
case SYCLUniqueStableNameExprClass:
case PackIndexingExprClass:
case HLSLOutArgExprClass:
// These never have a side-effect.
return false;

Expand Down Expand Up @@ -5318,3 +5319,13 @@ OMPIteratorExpr *OMPIteratorExpr::CreateEmpty(const ASTContext &Context,
alignof(OMPIteratorExpr));
return new (Mem) OMPIteratorExpr(EmptyShell(), NumIterators);
}

HLSLOutArgExpr *HLSLOutArgExpr::Create(const ASTContext &C, QualType Ty,
Expr *Base, bool IsInOut, Expr *WB,
OpaqueValueExpr *OpV) {
return new (C) HLSLOutArgExpr(Ty, Base, WB, OpV, IsInOut);
}

HLSLOutArgExpr *HLSLOutArgExpr::CreateEmpty(const ASTContext &C) {
return new (C) HLSLOutArgExpr(EmptyShell());
}
1 change: 1 addition & 0 deletions clang/lib/AST/ExprClassification.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,7 @@ static Cl::Kinds ClassifyInternal(ASTContext &Ctx, const Expr *E) {
case Expr::ArraySectionExprClass:
case Expr::OMPArrayShapingExprClass:
case Expr::OMPIteratorExprClass:
case Expr::HLSLOutArgExprClass:
return Cl::CL_LValue;

// C99 6.5.2.5p5 says that compound literals are lvalues.
Expand Down
1 change: 1 addition & 0 deletions clang/lib/AST/ExprConstant.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16469,6 +16469,7 @@ static ICEDiag CheckICE(const Expr* E, const ASTContext &Ctx) {
case Expr::CoyieldExprClass:
case Expr::SYCLUniqueStableNameExprClass:
case Expr::CXXParenListInitExprClass:
case Expr::HLSLOutArgExprClass:
return ICEDiag(IK_NotICE, E->getBeginLoc());

case Expr::InitListExprClass: {
Expand Down
12 changes: 12 additions & 0 deletions clang/lib/AST/ItaniumMangle.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3507,6 +3507,12 @@ CXXNameMangler::mangleExtParameterInfo(FunctionProtoType::ExtParameterInfo PI) {
case ParameterABI::Ordinary:
break;

// HLSL parameter mangling.
case ParameterABI::HLSLOut:
case ParameterABI::HLSLInOut:
mangleVendorQualifier(getParameterABISpelling(PI.getABI()));
break;

// All of these start with "swift", so they come before "ns_consumed".
case ParameterABI::SwiftContext:
case ParameterABI::SwiftAsyncContext:
Expand Down Expand Up @@ -5703,6 +5709,12 @@ void CXXNameMangler::mangleExpression(const Expr *E, unsigned Arity,
Out << "E";
break;
}
case Expr::HLSLOutArgExprClass: {
const auto *OAE = cast<clang::HLSLOutArgExpr>(E);
Out << (OAE->isInOut() ? "_inout_" : "_out_");
mangleType(E->getType());
break;
}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Okay, so this is wrong in every way. :)

Expression mangling generally follows the syntax. Since there is no argument-side syntax for passing an argument to an out or inout parameter, I'd say the correct mangling is just to ignore this node and mangle the underlying (syntactic) argument expression, which I assume is fairly easy to find (is that the base expression? whatever doesn't have the OVE still in it).

If it were important to mangle this for some reason, you would want to mangle it as a vendor extended expression, u <source-name> <template-arg>* E.

}

if (AsTemplateArg && !IsPrimaryExpr)
Expand Down
4 changes: 4 additions & 0 deletions clang/lib/AST/StmtPrinter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2799,6 +2799,10 @@ void StmtPrinter::VisitAsTypeExpr(AsTypeExpr *Node) {
OS << ")";
}

void StmtPrinter::VisitHLSLOutArgExpr(HLSLOutArgExpr *Node) {
PrintExpr(Node->getBase());
}

//===----------------------------------------------------------------------===//
// Stmt method implementations
//===----------------------------------------------------------------------===//
Expand Down
4 changes: 4 additions & 0 deletions clang/lib/AST/StmtProfile.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2631,6 +2631,10 @@ void StmtProfiler::VisitOpenACCLoopConstruct(const OpenACCLoopConstruct *S) {
P.VisitOpenACCClauseList(S->clauses());
}

void StmtProfiler::VisitHLSLOutArgExpr(const HLSLOutArgExpr *S) {
VisitStmt(S);
}

void Stmt::Profile(llvm::FoldingSetNodeID &ID, const ASTContext &Context,
bool Canonical, bool ProfileLambdaExpr) const {
StmtProfilerWithPointers Profiler(ID, Context, Canonical, ProfileLambdaExpr);
Expand Down
4 changes: 4 additions & 0 deletions clang/lib/AST/TextNodeDumper.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2874,6 +2874,10 @@ void TextNodeDumper::VisitHLSLBufferDecl(const HLSLBufferDecl *D) {
dumpName(D);
}

void TextNodeDumper::VisitHLSLOutArgExpr(const HLSLOutArgExpr *E) {
OS << (E->isInOut() ? " inout" : " out");
}

void TextNodeDumper::VisitOpenACCConstructStmt(const OpenACCConstructStmt *S) {
OS << " " << S->getDirectiveKind();
}
Expand Down
20 changes: 15 additions & 5 deletions clang/lib/AST/TypePrinter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -933,6 +933,10 @@ StringRef clang::getParameterABISpelling(ParameterABI ABI) {
return "swift_error_result";
case ParameterABI::SwiftIndirectResult:
return "swift_indirect_result";
case ParameterABI::HLSLOut:
return "out";
case ParameterABI::HLSLInOut:
return "inout";
}
llvm_unreachable("bad parameter ABI kind");
}
Expand All @@ -955,7 +959,17 @@ void TypePrinter::printFunctionProtoAfter(const FunctionProtoType *T,
if (EPI.isNoEscape())
OS << "__attribute__((noescape)) ";
auto ABI = EPI.getABI();
if (ABI != ParameterABI::Ordinary)
if (ABI == ParameterABI::HLSLInOut || ABI == ParameterABI::HLSLOut) {
OS << getParameterABISpelling(ABI) << " ";
if (Policy.UseHLSLTypes) {
// This is a bit of a hack because we _do_ use reference types in the
// AST for representing inout and out parameters so that code
// generation is sane, but when re-printing these for HLSL we need to
// skip the reference.
print(T->getParamType(i).getNonReferenceType(), OS, StringRef());
continue;
}
} else if (ABI != ParameterABI::Ordinary)
OS << "__attribute__((" << getParameterABISpelling(ABI) << ")) ";

print(T->getParamType(i), OS, StringRef());
Expand Down Expand Up @@ -2023,10 +2037,6 @@ void TypePrinter::printAttributedAfter(const AttributedType *T,
case attr::ArmMveStrictPolymorphism:
OS << "__clang_arm_mve_strict_polymorphism";
break;

// Nothing to print for this attribute.
case attr::HLSLParamModifier:
break;
}
OS << "))";
}
Expand Down
Loading
Loading