Skip to content

Commit e8ec3e2

Browse files
committed
[HLSL] Implement output parameter
HLSL output parameters are denoted with the `inout` and `out` keywords in the function declaration. When an argument to an output parameter is constructed a temporary value is constructed for the argument. For `inout` pamameters the argument is intialized by casting the argument expression to the parameter type. For `out` parameters the argument is not initialized before the call. In both cases on return of the function the temporary value is written back to the argument lvalue expression through an optional casting sequence if required. This change introduces a new HLSLOutArgExpr ast node which represents the output argument behavior. The OutArgExpr has two defined children: the base expresion and the writeback expression. The writeback expression will either be or contain an OpaqueValueExpr child expression which is used during code generation to represent the temporary value.
1 parent fb70282 commit e8ec3e2

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

43 files changed

+854
-45
lines changed

clang/include/clang/AST/ASTContext.h

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1377,6 +1377,14 @@ class ASTContext : public RefCountedBase<ASTContext> {
13771377
/// in the return type and parameter types.
13781378
bool hasSameFunctionTypeIgnoringPtrSizes(QualType T, QualType U);
13791379

1380+
/// Get or construct a function type that is equivalent to the input type
1381+
/// except that the parameter ABI annotations are stripped.
1382+
QualType getFunctionTypeWithoutParamABIs(QualType T);
1383+
1384+
/// Determine if two function types are the same, ignoring parameter ABI
1385+
/// annotations.
1386+
bool hasSameFunctionTypeIgnoringParamABI(QualType T, QualType U);
1387+
13801388
/// Return the uniqued reference to the type for a complex
13811389
/// number with the specified element type.
13821390
QualType getComplexType(QualType T) const;

clang/include/clang/AST/Attr.h

Lines changed: 24 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -224,20 +224,7 @@ class ParameterABIAttr : public InheritableParamAttr {
224224
InheritEvenIfAlreadyPresent) {}
225225

226226
public:
227-
ParameterABI getABI() const {
228-
switch (getKind()) {
229-
case attr::SwiftContext:
230-
return ParameterABI::SwiftContext;
231-
case attr::SwiftAsyncContext:
232-
return ParameterABI::SwiftAsyncContext;
233-
case attr::SwiftErrorResult:
234-
return ParameterABI::SwiftErrorResult;
235-
case attr::SwiftIndirectResult:
236-
return ParameterABI::SwiftIndirectResult;
237-
default:
238-
llvm_unreachable("bad parameter ABI attribute kind");
239-
}
240-
}
227+
ParameterABI getABI() const;
241228

242229
static bool classof(const Attr *A) {
243230
return A->getKind() >= attr::FirstParameterABIAttr &&
@@ -379,6 +366,29 @@ inline const StreamingDiagnostic &operator<<(const StreamingDiagnostic &DB,
379366
DB.AddTaggedVal(reinterpret_cast<uint64_t>(At), DiagnosticsEngine::ak_attr);
380367
return DB;
381368
}
369+
370+
inline ParameterABI ParameterABIAttr::getABI() const {
371+
switch (getKind()) {
372+
case attr::SwiftContext:
373+
return ParameterABI::SwiftContext;
374+
case attr::SwiftAsyncContext:
375+
return ParameterABI::SwiftAsyncContext;
376+
case attr::SwiftErrorResult:
377+
return ParameterABI::SwiftErrorResult;
378+
case attr::SwiftIndirectResult:
379+
return ParameterABI::SwiftIndirectResult;
380+
case attr::HLSLParamModifier: {
381+
const auto *A = cast<HLSLParamModifierAttr>(this);
382+
if (A->isOut())
383+
return ParameterABI::HLSLOut;
384+
if (A->isInOut())
385+
return ParameterABI::HLSLInOut;
386+
return ParameterABI::Ordinary;
387+
}
388+
default:
389+
llvm_unreachable("bad parameter ABI attribute kind");
390+
}
391+
}
382392
} // end namespace clang
383393

384394
#endif

clang/include/clang/AST/Expr.h

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7061,6 +7061,67 @@ class ArraySectionExpr : public Expr {
70617061
void setRBracketLoc(SourceLocation L) { RBracketLoc = L; }
70627062
};
70637063

7064+
/// This class represents temporary values used to represent inout and out
7065+
/// arguments in HLSL. From the callee perspective these parameters are more or
7066+
/// less __restrict__ T&. They are guaranteed to not alias any memory. inout
7067+
/// parameters are initialized by the caller, and out parameters are references
7068+
/// to uninitialized memory.
7069+
///
7070+
/// In the caller, the argument expression creates a temporary in local memory
7071+
/// and the address of the temporary is passed into the callee. There may be
7072+
/// implicit conversion sequences to initialize the temporary, and on expiration
7073+
/// of the temporary an inverse conversion sequence is applied as a write-back
7074+
/// conversion to the source l-value.
7075+
class HLSLOutArgExpr : public Expr {
7076+
friend class ASTStmtReader;
7077+
7078+
Expr *Base;
7079+
Expr *Writeback;
7080+
OpaqueValueExpr *OpaqueVal;
7081+
bool IsInOut;
7082+
7083+
HLSLOutArgExpr(QualType Ty, Expr *B, Expr *WB, OpaqueValueExpr *OpV,
7084+
bool IsInOut)
7085+
: Expr(HLSLOutArgExprClass, Ty, VK_LValue, OK_Ordinary), Base(B),
7086+
Writeback(WB), OpaqueVal(OpV), IsInOut(IsInOut) {
7087+
assert(!Ty->isDependentType() && "HLSLOutArgExpr given a dependent type!");
7088+
}
7089+
7090+
explicit HLSLOutArgExpr(EmptyShell Shell)
7091+
: Expr(HLSLOutArgExprClass, Shell) {}
7092+
7093+
public:
7094+
static HLSLOutArgExpr *Create(const ASTContext &C, QualType Ty, Expr *Base,
7095+
bool IsInOut, Expr *WB, OpaqueValueExpr *OpV);
7096+
static HLSLOutArgExpr *CreateEmpty(const ASTContext &Ctx);
7097+
7098+
const Expr *getBase() const { return Base; }
7099+
Expr *getBase() { return Base; }
7100+
7101+
const Expr *getWriteback() const { return Writeback; }
7102+
Expr *getWriteback() { return Writeback; }
7103+
7104+
const OpaqueValueExpr *getOpaqueValue() const { return OpaqueVal; }
7105+
OpaqueValueExpr *getOpaqueValue() { return OpaqueVal; }
7106+
7107+
bool isInOut() const { return IsInOut; }
7108+
7109+
SourceLocation getBeginLoc() const LLVM_READONLY {
7110+
return Base->getBeginLoc();
7111+
}
7112+
7113+
SourceLocation getEndLoc() const LLVM_READONLY { return Base->getEndLoc(); }
7114+
7115+
static bool classof(const Stmt *T) {
7116+
return T->getStmtClass() == HLSLOutArgExprClass;
7117+
}
7118+
7119+
// Iterators
7120+
child_range children() {
7121+
return child_range((Stmt **)&Base, ((Stmt **)&Writeback) + 1);
7122+
}
7123+
};
7124+
70647125
/// Frontend produces RecoveryExprs on semantic errors that prevent creating
70657126
/// other well-formed expressions. E.g. when type-checking of a binary operator
70667127
/// fails, we cannot produce a BinaryOperator expression. Instead, we can choose

clang/include/clang/AST/RecursiveASTVisitor.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4014,6 +4014,9 @@ DEF_TRAVERSE_STMT(OpenACCComputeConstruct,
40144014
DEF_TRAVERSE_STMT(OpenACCLoopConstruct,
40154015
{ TRY_TO(TraverseOpenACCAssociatedStmtConstruct(S)); })
40164016

4017+
// Traverse HLSL: Out argument expression
4018+
DEF_TRAVERSE_STMT(HLSLOutArgExpr, {})
4019+
40174020
// FIXME: look at the following tricky-seeming exprs to see if we
40184021
// need to recurse on anything. These are ones that have methods
40194022
// returning decls or qualtypes or nestednamespecifier -- though I'm

clang/include/clang/AST/TextNodeDumper.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -407,6 +407,7 @@ class TextNodeDumper
407407
void
408408
VisitLifetimeExtendedTemporaryDecl(const LifetimeExtendedTemporaryDecl *D);
409409
void VisitHLSLBufferDecl(const HLSLBufferDecl *D);
410+
void VisitHLSLOutArgExpr(const HLSLOutArgExpr *E);
410411
void VisitOpenACCConstructStmt(const OpenACCConstructStmt *S);
411412
void VisitOpenACCLoopConstruct(const OpenACCLoopConstruct *S);
412413
void VisitEmbedExpr(const EmbedExpr *S);

clang/include/clang/Basic/Attr.td

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4613,14 +4613,13 @@ def HLSLGroupSharedAddressSpace : TypeAttr {
46134613
let Documentation = [HLSLGroupSharedAddressSpaceDocs];
46144614
}
46154615

4616-
def HLSLParamModifier : TypeAttr {
4616+
def HLSLParamModifier : ParameterABIAttr {
46174617
let Spellings = [CustomKeyword<"in">, CustomKeyword<"inout">, CustomKeyword<"out">];
46184618
let Accessors = [Accessor<"isIn", [CustomKeyword<"in">]>,
46194619
Accessor<"isInOut", [CustomKeyword<"inout">]>,
46204620
Accessor<"isOut", [CustomKeyword<"out">]>,
46214621
Accessor<"isAnyOut", [CustomKeyword<"out">, CustomKeyword<"inout">]>,
46224622
Accessor<"isAnyIn", [CustomKeyword<"in">, CustomKeyword<"inout">]>];
4623-
let Subjects = SubjectList<[ParmVar]>;
46244623
let Documentation = [HLSLParamQualifierDocs];
46254624
let Args = [DefaultBoolArgument<"MergedSpelling", /*default*/0, /*fake*/1>];
46264625
}

clang/include/clang/Basic/DiagnosticSemaKinds.td

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12357,6 +12357,8 @@ def warn_hlsl_availability : Warning<
1235712357
def warn_hlsl_availability_unavailable :
1235812358
Warning<err_unavailable.Summary>,
1235912359
InGroup<HLSLAvailability>, DefaultError;
12360+
def error_hlsl_inout_scalar_extension : Error<"illegal scalar extension cast on argument %0 to %select{|in}1out paramemter">;
12361+
def error_hlsl_inout_lvalue : Error<"cannot bind non-lvalue argument %0 to %select{|in}1out paramemter">;
1236012362

1236112363
def err_hlsl_export_not_on_function : Error<
1236212364
"export declaration can only be used on functions">;

clang/include/clang/Basic/Specifiers.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -382,6 +382,12 @@ namespace clang {
382382
/// Swift asynchronous context-pointer ABI treatment. There can be at
383383
/// most one parameter on a given function that uses this treatment.
384384
SwiftAsyncContext,
385+
386+
// This parameter is a copy-out HLSL parameter.
387+
HLSLOut,
388+
389+
// This parameter is a copy-in/copy-out HLSL parameter.
390+
HLSLInOut,
385391
};
386392

387393
/// Assigned inheritance model for a class in the MS C++ ABI. Must match order

clang/include/clang/Basic/StmtNodes.td

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -306,3 +306,6 @@ def OpenACCAssociatedStmtConstruct
306306
: StmtNode<OpenACCConstructStmt, /*abstract=*/1>;
307307
def OpenACCComputeConstruct : StmtNode<OpenACCAssociatedStmtConstruct>;
308308
def OpenACCLoopConstruct : StmtNode<OpenACCAssociatedStmtConstruct>;
309+
310+
// HLSL Constructs.
311+
def HLSLOutArgExpr : StmtNode<Expr>;

clang/include/clang/Sema/SemaHLSL.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,8 @@ class SemaHLSL : public SemaBase {
6161
void handleParamModifierAttr(Decl *D, const ParsedAttr &AL);
6262

6363
bool CheckBuiltinFunctionCall(unsigned BuiltinID, CallExpr *TheCall);
64+
65+
ExprResult ActOnOutParamExpr(ParmVarDecl *Param, Expr *Arg);
6466
};
6567

6668
} // namespace clang

clang/include/clang/Serialization/ASTBitCodes.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1988,6 +1988,9 @@ enum StmtCode {
19881988
// OpenACC Constructs
19891989
STMT_OPENACC_COMPUTE_CONSTRUCT,
19901990
STMT_OPENACC_LOOP_CONSTRUCT,
1991+
1992+
// HLSL Constructs
1993+
EXPR_HLSL_OUT_ARG,
19911994
};
19921995

19931996
/// The kinds of designators that can occur in a

clang/lib/AST/ASTContext.cpp

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3590,6 +3590,20 @@ bool ASTContext::hasSameFunctionTypeIgnoringPtrSizes(QualType T, QualType U) {
35903590
getFunctionTypeWithoutPtrSizes(U));
35913591
}
35923592

3593+
QualType ASTContext::getFunctionTypeWithoutParamABIs(QualType T) {
3594+
if (const auto *Proto = T->getAs<FunctionProtoType>()) {
3595+
FunctionProtoType::ExtProtoInfo EPI = Proto->getExtProtoInfo();
3596+
EPI.ExtParameterInfos = nullptr;
3597+
return getFunctionType(Proto->getReturnType(), Proto->param_types(), EPI);
3598+
}
3599+
return T;
3600+
}
3601+
3602+
bool ASTContext::hasSameFunctionTypeIgnoringParamABI(QualType T, QualType U) {
3603+
return hasSameType(T, U) || hasSameType(getFunctionTypeWithoutParamABIs(T),
3604+
getFunctionTypeWithoutParamABIs(U));
3605+
}
3606+
35933607
void ASTContext::adjustExceptionSpec(
35943608
FunctionDecl *FD, const FunctionProtoType::ExceptionSpecInfo &ESI,
35953609
bool AsWritten) {

clang/lib/AST/Expr.cpp

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3631,6 +3631,7 @@ bool Expr::HasSideEffects(const ASTContext &Ctx,
36313631
case RequiresExprClass:
36323632
case SYCLUniqueStableNameExprClass:
36333633
case PackIndexingExprClass:
3634+
case HLSLOutArgExprClass:
36343635
// These never have a side-effect.
36353636
return false;
36363637

@@ -5318,3 +5319,13 @@ OMPIteratorExpr *OMPIteratorExpr::CreateEmpty(const ASTContext &Context,
53185319
alignof(OMPIteratorExpr));
53195320
return new (Mem) OMPIteratorExpr(EmptyShell(), NumIterators);
53205321
}
5322+
5323+
HLSLOutArgExpr *HLSLOutArgExpr::Create(const ASTContext &C, QualType Ty,
5324+
Expr *Base, bool IsInOut, Expr *WB,
5325+
OpaqueValueExpr *OpV) {
5326+
return new (C) HLSLOutArgExpr(Ty, Base, WB, OpV, IsInOut);
5327+
}
5328+
5329+
HLSLOutArgExpr *HLSLOutArgExpr::CreateEmpty(const ASTContext &C) {
5330+
return new (C) HLSLOutArgExpr(EmptyShell());
5331+
}

clang/lib/AST/ExprClassification.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -148,6 +148,7 @@ static Cl::Kinds ClassifyInternal(ASTContext &Ctx, const Expr *E) {
148148
case Expr::ArraySectionExprClass:
149149
case Expr::OMPArrayShapingExprClass:
150150
case Expr::OMPIteratorExprClass:
151+
case Expr::HLSLOutArgExprClass:
151152
return Cl::CL_LValue;
152153

153154
// C99 6.5.2.5p5 says that compound literals are lvalues.

clang/lib/AST/ExprConstant.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16469,6 +16469,7 @@ static ICEDiag CheckICE(const Expr* E, const ASTContext &Ctx) {
1646916469
case Expr::CoyieldExprClass:
1647016470
case Expr::SYCLUniqueStableNameExprClass:
1647116471
case Expr::CXXParenListInitExprClass:
16472+
case Expr::HLSLOutArgExprClass:
1647216473
return ICEDiag(IK_NotICE, E->getBeginLoc());
1647316474

1647416475
case Expr::InitListExprClass: {

clang/lib/AST/ItaniumMangle.cpp

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3507,6 +3507,12 @@ CXXNameMangler::mangleExtParameterInfo(FunctionProtoType::ExtParameterInfo PI) {
35073507
case ParameterABI::Ordinary:
35083508
break;
35093509

3510+
// HLSL parameter mangling.
3511+
case ParameterABI::HLSLOut:
3512+
case ParameterABI::HLSLInOut:
3513+
mangleVendorQualifier(getParameterABISpelling(PI.getABI()));
3514+
break;
3515+
35103516
// All of these start with "swift", so they come before "ns_consumed".
35113517
case ParameterABI::SwiftContext:
35123518
case ParameterABI::SwiftAsyncContext:
@@ -5703,6 +5709,12 @@ void CXXNameMangler::mangleExpression(const Expr *E, unsigned Arity,
57035709
Out << "E";
57045710
break;
57055711
}
5712+
case Expr::HLSLOutArgExprClass: {
5713+
const auto *OAE = cast<clang::HLSLOutArgExpr>(E);
5714+
Out << (OAE->isInOut() ? "_inout_" : "_out_");
5715+
mangleType(E->getType());
5716+
break;
5717+
}
57065718
}
57075719

57085720
if (AsTemplateArg && !IsPrimaryExpr)

clang/lib/AST/StmtPrinter.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2799,6 +2799,10 @@ void StmtPrinter::VisitAsTypeExpr(AsTypeExpr *Node) {
27992799
OS << ")";
28002800
}
28012801

2802+
void StmtPrinter::VisitHLSLOutArgExpr(HLSLOutArgExpr *Node) {
2803+
PrintExpr(Node->getBase());
2804+
}
2805+
28022806
//===----------------------------------------------------------------------===//
28032807
// Stmt method implementations
28042808
//===----------------------------------------------------------------------===//

clang/lib/AST/StmtProfile.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2631,6 +2631,10 @@ void StmtProfiler::VisitOpenACCLoopConstruct(const OpenACCLoopConstruct *S) {
26312631
P.VisitOpenACCClauseList(S->clauses());
26322632
}
26332633

2634+
void StmtProfiler::VisitHLSLOutArgExpr(const HLSLOutArgExpr *S) {
2635+
VisitStmt(S);
2636+
}
2637+
26342638
void Stmt::Profile(llvm::FoldingSetNodeID &ID, const ASTContext &Context,
26352639
bool Canonical, bool ProfileLambdaExpr) const {
26362640
StmtProfilerWithPointers Profiler(ID, Context, Canonical, ProfileLambdaExpr);

clang/lib/AST/TextNodeDumper.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2874,6 +2874,10 @@ void TextNodeDumper::VisitHLSLBufferDecl(const HLSLBufferDecl *D) {
28742874
dumpName(D);
28752875
}
28762876

2877+
void TextNodeDumper::VisitHLSLOutArgExpr(const HLSLOutArgExpr *E) {
2878+
OS << (E->isInOut() ? " inout" : " out");
2879+
}
2880+
28772881
void TextNodeDumper::VisitOpenACCConstructStmt(const OpenACCConstructStmt *S) {
28782882
OS << " " << S->getDirectiveKind();
28792883
}

clang/lib/AST/TypePrinter.cpp

Lines changed: 15 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -933,6 +933,10 @@ StringRef clang::getParameterABISpelling(ParameterABI ABI) {
933933
return "swift_error_result";
934934
case ParameterABI::SwiftIndirectResult:
935935
return "swift_indirect_result";
936+
case ParameterABI::HLSLOut:
937+
return "out";
938+
case ParameterABI::HLSLInOut:
939+
return "inout";
936940
}
937941
llvm_unreachable("bad parameter ABI kind");
938942
}
@@ -955,7 +959,17 @@ void TypePrinter::printFunctionProtoAfter(const FunctionProtoType *T,
955959
if (EPI.isNoEscape())
956960
OS << "__attribute__((noescape)) ";
957961
auto ABI = EPI.getABI();
958-
if (ABI != ParameterABI::Ordinary)
962+
if (ABI == ParameterABI::HLSLInOut || ABI == ParameterABI::HLSLOut) {
963+
OS << getParameterABISpelling(ABI) << " ";
964+
if (Policy.UseHLSLTypes) {
965+
// This is a bit of a hack because we _do_ use reference types in the
966+
// AST for representing inout and out parameters so that code
967+
// generation is sane, but when re-printing these for HLSL we need to
968+
// skip the reference.
969+
print(T->getParamType(i).getNonReferenceType(), OS, StringRef());
970+
continue;
971+
}
972+
} else if (ABI != ParameterABI::Ordinary)
959973
OS << "__attribute__((" << getParameterABISpelling(ABI) << ")) ";
960974

961975
print(T->getParamType(i), OS, StringRef());
@@ -2023,10 +2037,6 @@ void TypePrinter::printAttributedAfter(const AttributedType *T,
20232037
case attr::ArmMveStrictPolymorphism:
20242038
OS << "__clang_arm_mve_strict_polymorphism";
20252039
break;
2026-
2027-
// Nothing to print for this attribute.
2028-
case attr::HLSLParamModifier:
2029-
break;
20302040
}
20312041
OS << "))";
20322042
}

0 commit comments

Comments
 (0)